Cherry-pick of https://github.com/vllm-project/vllm-ascend/pull/3677.

### What this PR does / why we need it?

Remove redundant operations from `model_runner` and `forward_context`. This optimization significantly reduces the idle time (bubble) before decoding when running models with small parameter counts (e.g., Qwen/Qwen2.5-0.5B).

Tested on 800I A2, the bubble is reduced from 3.8 ms to 2.8 ms:

Before
<img width="1655" height="696" alt="image" src="https://github.com/user-attachments/assets/d7608e52-2438-46dd-8fc9-391fd6274495" />

After
<img width="1607" height="774" alt="image" src="https://github.com/user-attachments/assets/56daf081-2dba-4d2e-99d4-e055187d9806" />

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

---------

Signed-off-by: realliujiaxu <realliujiaxu@163.com>
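A rough way to sanity-check the end-to-end effect of this change is to time a short decode-heavy run before and after applying the patch. The sketch below is illustrative only (it is not part of the PR; the model and prompt are placeholders), and it measures total latency rather than the per-step bubble, which the screenshots above capture with the profiler:

```python
import time

from vllm import LLM, SamplingParams

# Small model, where the host-side bubble before decode dominates step time.
llm = LLM(model="Qwen/Qwen2.5-0.5B")
params = SamplingParams(max_tokens=128)

start = time.perf_counter()
llm.generate(["Hello, my name is"] * 32, params)
print(f"elapsed: {time.perf_counter() - start:.3f} s")
```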
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
#

import copy
import gc
import itertools
import math
import re
import time
from collections import defaultdict
from collections.abc import Iterator
from contextlib import contextmanager, nullcontext
from copy import deepcopy
from dataclasses import dataclass
from multiprocessing import Manager
from typing import (TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional,
                    Union, cast)

import numpy as np
import numpy.typing as npt
import torch
import torch._dynamo.cache_size
import torch.distributed as dist
import torch.nn as nn
from tqdm import tqdm  # type: ignore
from vllm.attention import AttentionType, get_attn_backend
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.layer import Attention
from vllm.compilation.counter import compilation_counter
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
from vllm.config import (CompilationLevel, CUDAGraphMode, VllmConfig,
                         get_layers_from_vllm_config)
from vllm.distributed import tensor_model_parallel_all_gather
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                          has_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.distributed.parallel_state import (get_dp_group, get_pp_group,
                                             get_tp_group,
                                             is_global_first_rank)
from vllm.forward_context import BatchDescriptor, get_forward_context
from vllm.logger import logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models.interfaces import supports_transcription
from vllm.model_executor.models.interfaces_base import (
    VllmModelForPooling, is_pooling_model, is_text_generation_model)
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
from vllm.multimodal.utils import group_mm_kwargs_by_modality
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                        LazyLoader, cdiv, get_dtype_size,
                        is_pin_memory_available)
from vllm.utils.jsontree import json_map_leaves
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
from vllm.v1.attention.backends.utils import (
    AttentionCGSupport, reorder_batch_to_split_decodes_and_prefills)
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
# yapf conflicts with isort for this block
# yapf: disable
from vllm.v1.kv_cache_interface import (AttentionSpec,
                                        EncoderOnlyAttentionSpec,
                                        FullAttentionSpec, KVCacheConfig,
                                        KVCacheGroupSpec, KVCacheSpec,
                                        MambaSpec, MLAAttentionSpec,
                                        UniformTypeKVCacheSpecs)
# yapf: enable
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
                             DraftTokenIds, LogprobsTensors, ModelRunnerOutput,
                             PoolerOutput)
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm.v1.worker.utils import (AttentionGroup, bind_kv_cache,
                                  gather_mm_placeholders,
                                  sanity_check_mm_encoder_outputs,
                                  scatter_mm_placeholders)

import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import (MoECommType,
                                                set_ascend_forward_context)
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
                                               set_graph_params,
                                               update_attn_params,
                                               update_mla_attn_params)
from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor
from vllm_ascend.eplb.core.eplb_device_transfer_loader import \
    D2DExpertWeightLoader
from vllm_ascend.eplb.core.eplb_utils import EPLBParamUtils
from vllm_ascend.eplb.core.eplb_worker import EplbProcess
from vllm_ascend.eplb.eplb_updator import EplbUpdator
from vllm_ascend.eplb.utils import model_register
from vllm_ascend.models.layers.mla import AscendMultiHeadLatentAttention
from vllm_ascend.multistream.ms_split import compute_split_seq_index
from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.sample.logits_processor import build_logitsprocs
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
from vllm_ascend.spec_decode import get_spec_decode_method
from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
from vllm_ascend.spec_decode.interface import SpecDcodeType
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
                               AscendSocVersion, ProfileExecuteDuration,
                               enable_sp, get_ascend_soc_version, is_310p,
                               is_enable_nz, is_moe_model, lmhead_tp_enable)
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch

if TYPE_CHECKING:
    import xgrammar as xgr  # type: ignore[import-untyped]
    from vllm.v1.core.sched.output import SchedulerOutput
else:
    xgr = LazyLoader("xgr", globals(), "xgrammar")

import torch_npu

# if true, allow tensor initialization and casting with internal format (e.g., NZ)
torch.npu.config.allow_internal_format = True

if is_310p():
    torch_npu.npu.set_compile_mode(jit_compile=False)
    ACL_FORMAT = ACL_FORMAT_FRACTAL_NZ
else:
    ACL_FORMAT = ACL_FORMAT_FRACTAL_ND


@dataclass
class GraphCaptureContext:
    stream: torch.npu.Stream


@contextmanager
def graph_capture(device: torch.device):
    """
    `graph_capture` is a context manager which should surround the code that
    is capturing the NPU graph. Its main purpose is to ensure that some
    operations will be run after the graph is captured, before the graph
    is replayed. It returns a `GraphCaptureContext` object which contains the
    necessary data for the graph capture. Currently, it only contains the
    stream that the graph capture is running on. This stream is set to the
    current NPU stream when the context manager is entered and reset to the
    default stream when the context manager is exited. This is to ensure that
    the graph capture is running on a separate stream from the default stream,
    in order to explicitly distinguish the kernels to capture
    from other kernels possibly launched in the background on the default stream.
    """
    graph_capture_context = GraphCaptureContext(
        torch.npu.Stream(device=device))
    stream = graph_capture_context.stream

    # we use nullcontext now
    maybe_ca_context = nullcontext()

    # ensure all initialization operations complete before attempting to
    # capture the graph on another stream
    curr_stream = torch.npu.current_stream()
    if curr_stream != stream:
        stream.wait_stream(curr_stream)

    with torch.npu.stream(stream), maybe_ca_context:
        yield graph_capture_context


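# Illustrative usage of `graph_capture` (not part of the original module):
# every kernel that should be recorded into the graph is launched while the
# context is active, so it runs on the dedicated capture stream instead of
# the default stream.
#
#     with graph_capture(device=device) as capture_ctx:
#         assert torch.npu.current_stream() == capture_ctx.stream
#         ...  # launch the forward pass to be captured here
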
# Wrapper for ModelRunnerOutput to support overlapped execution.
class AsyncNPUModelRunnerOutput(AsyncModelRunnerOutput):

    def __init__(
        self,
        model_runner_output: ModelRunnerOutput,
        sampled_token_ids: torch.Tensor,
        invalid_req_indices: list[int],
        async_output_copy_stream: torch.npu.Stream,
    ):
        self._model_runner_output = model_runner_output
        self._invalid_req_indices = invalid_req_indices

        # Event on the copy stream so we can synchronize the non-blocking copy.
        self._async_copy_ready_event = torch.npu.Event()

        # Keep a reference to the device tensor to avoid it being
        # deallocated until we finish copying it to the host.
        self._sampled_token_ids = sampled_token_ids

        # Initiate the copy on a separate stream, but do not synchronize it.
        default_stream = torch.npu.current_stream()
        with torch.npu.stream(async_output_copy_stream):
            async_output_copy_stream.wait_stream(default_stream)
            self._sampled_token_ids_cpu = self._sampled_token_ids.to(
                'cpu', non_blocking=True)
            self._async_copy_ready_event.record()

    def get_output(self) -> ModelRunnerOutput:
        """Copy the device tensors to the host and return a ModelRunnerOutput.

        This function blocks until the copy is finished.
        """
        self._async_copy_ready_event.synchronize()

        # Release the device tensor once the copy has completed
        del self._sampled_token_ids

        valid_sampled_token_ids = self._sampled_token_ids_cpu.tolist()
        for i in self._invalid_req_indices:
            valid_sampled_token_ids[i].clear()

        output = self._model_runner_output
        output.sampled_token_ids = valid_sampled_token_ids
        return output


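# Illustrative flow for AsyncNPUModelRunnerOutput (not part of the original
# module): the runner wraps its output right after sampling so that the
# device-to-host copy of sampled token ids overlaps with subsequent work, and
# the caller later invokes `get_output()`, which blocks only until the copy
# event recorded on `async_output_copy_stream` has completed.
#
#     async_output = AsyncNPUModelRunnerOutput(
#         model_runner_output=output,
#         sampled_token_ids=sampled_token_ids,
#         invalid_req_indices=invalid_req_indices,
#         async_output_copy_stream=self.async_output_copy_stream,
#     )
#     ...  # continue issuing work on the default stream
#     final_output = async_output.get_output()

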
class NPUModelRunner(LoRAModelRunnerMixin):

    def __init__(self, vllm_config: VllmConfig, device: torch.device):
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.compilation_config = vllm_config.compilation_config
        self.load_config = vllm_config.load_config
        self.lora_config = vllm_config.lora_config
        self.parallel_config = vllm_config.parallel_config
        self.pin_memory = is_pin_memory_available()
        self.scheduler_config = vllm_config.scheduler_config
        self.speculative_config = vllm_config.speculative_config
        self.block_size = vllm_config.cache_config.block_size
        self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len,
                                           self.block_size)
        self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
        decode_max_num_seqs = getattr(self.scheduler_config,
                                      'decode_max_num_seqs', 0)
        self.max_num_reqs = max(self.scheduler_config.max_num_seqs,
                                decode_max_num_seqs)
        self.dp_size = vllm_config.parallel_config.data_parallel_size
        self.dp_rank = vllm_config.parallel_config.data_parallel_rank
        self.device = device
        if envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP:
            self.prefetch_stream = torch.npu.Stream(device=device)
        else:
            self.prefetch_stream = None
        self.dtype = self.model_config.dtype
        if envs_ascend.VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION:
            # TODO: drop the env config to use ascend sampler by default
            from vllm_ascend.sample.sampler import AscendSampler

            self.sampler = AscendSampler()
        else:
            from vllm.v1.sample.sampler import Sampler

            self.sampler = Sampler()
        self.reorder_batch_threshold: Optional[int] = None

        # Lazy initialization, these will be set after __init__
        self.kv_caches: List[torch.Tensor] = []
        self.attn_groups: list[list[AttentionGroup]] = []
        self.encoder_cache: Dict[str, torch.Tensor] = {}
        self.attn_mask = None
        self.attn_state = None
        self.requests: Dict[str, CachedRequestState] = {}
        self.intermediate_tensors: Optional[IntermediateTensors] = None
        self.runner_only_attn_layers: set[str] = set()

        self.ascend_config = get_ascend_config()
        if self.ascend_config.ascend_scheduler_config.enabled:
            self.chunked_prefill_enabled = self.scheduler_config.chunked_prefill_enabled
        else:
            self.chunked_prefill_enabled = True
        self.weight_prefetch_method = WeightPrefetchMethod(
            self.ascend_config.weight_prefetch_config)

        if self.cache_config.cache_dtype == "auto":
            self.kv_cache_dtype = self.dtype
        else:
            self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
                self.cache_config.cache_dtype]
        # use_hybrid_blocks: whether hybrid blocks are used.
        self.use_hybrid_blocks: bool = False
        self.need_accepted_tokens: bool = False

        self.is_multimodal_model = self.model_config.is_multimodal_model
        self.is_pooling_model = self.model_config.pooler_config is not None
        if self.is_multimodal_model:
            self.inputs_embeds = torch.zeros(
                (self.max_num_tokens, self.model_config.get_hidden_size()),
                dtype=self.dtype,
                device=self.device)
        # Set up Attention
        self.use_sparse = hasattr(self.vllm_config.model_config.hf_config,
                                  "index_topk")
        self.attn_backend = get_attn_backend(0,
                                             self.dtype,
                                             None,
                                             self.block_size,
                                             use_mla=self.model_config.use_mla,
                                             use_sparse=self.use_sparse)
        if torch.version.cann.startswith("8.3"):
            self.attn_mask_builder = AttentionMaskBuilder(
                self.scheduler_config.max_num_batched_tokens, self.dtype,
                self.device)
        else:
            self.attn_mask_builder = AttentionMaskBuilder(
                self.model_config.max_model_len, self.dtype)

        # Set up speculative decoding.
        self.spec_attn_mask = None
        self.drafter: Optional[Union[NgramProposer, EagleProposer,
                                     MtpProposer]] = None
        self.actual_seq_lengths_q: list[int] = []
        self.decode_token_per_req = 1
        if self.speculative_config:
            spec_token_num = self.speculative_config.num_speculative_tokens
            assert spec_token_num > 0
            self.decode_token_per_req = 1 + spec_token_num
            self.spec_attn_mask = torch.triu(torch.ones(2048,
                                                        2048,
                                                        dtype=torch.bool),
                                             diagonal=1).to(self.device)
            if get_pp_group().is_last_rank:
                self.drafter = get_spec_decode_method(
                    self.speculative_config.method, self.vllm_config,
                    self.device, self)
                self.rejection_sampler = AscendRejectionSampler()
            self.actual_seq_lengths_q = list(
                range(self.decode_token_per_req, self.max_num_tokens + 1,
                      self.decode_token_per_req))

        # kv role
        self.is_kv_producer = False
        self.is_kv_consumer = False
        if vllm_config.kv_transfer_config is not None:
            self.is_kv_producer = vllm_config.kv_transfer_config.is_kv_producer
            self.is_kv_consumer = vllm_config.kv_transfer_config.is_kv_consumer

        self._may_pad_kv_consumer_num_seq()

        # Persistent batch.
        self.input_ids = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int32,
                                     device=self.device)
        self.positions = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int64,
                                     device=self.device)
        self.query_start_loc = torch.zeros(self.max_num_reqs + 1,
                                           dtype=torch.int32,
                                           device=self.device)
        self.seq_lens = torch.zeros(self.max_num_reqs,
                                    dtype=torch.int32,
                                    device=self.device)
        self.slot_mapping = torch.zeros(self.max_num_tokens,
                                        dtype=torch.int32,
                                        device=self.device)

        if self.vllm_config.model_config.use_mla and \
            self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
            rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
            self.cos = torch.ones(self.max_num_reqs *
                                  self.decode_token_per_req,
                                  1,
                                  1,
                                  rope_dim,
                                  dtype=self.dtype,
                                  device=self.device)
            self.sin = torch.zeros(self.max_num_reqs *
                                   self.decode_token_per_req,
                                   1,
                                   1,
                                   rope_dim,
                                   dtype=self.dtype,
                                   device=self.device)
        else:
            self.cos = None
            self.sin = None

        self.uses_mrope = self.model_config.uses_mrope
        # Only relevant for models using M-RoPE (e.g., Qwen2-VL)
        if self.uses_mrope:
            # NOTE: `mrope_positions` is implemented with one additional dummy
            # position on purpose to make it non-contiguous so that it can work
            # with torch compile.
            # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923

            # NOTE: When M-RoPE is enabled, position ids are 3D regardless of
            # the modality of inputs. For text-only inputs, each dimension has
            # identical position IDs, making M-RoPE functionally equivalent to
            # 1D-RoPE.
            # See page 5 of https://arxiv.org/abs/2409.12191
            self.mrope_positions = torch.zeros((3, self.max_num_tokens + 1),
                                               dtype=torch.int64,
                                               device=self.device)
            self.mrope_positions_cpu = torch.zeros(
                (3, self.max_num_tokens + 1),
                dtype=torch.int64,
                device="cpu",
                pin_memory=True)
            self.mrope_positions_np = self.mrope_positions_cpu.numpy()

        # OPTIMIZATION: Cache the tensors rather than creating them every step.
        self.arange_np: npt.NDArray[np.int32] = np.arange(max(
            self.max_num_reqs + 1, self.model_config.max_model_len,
            self.max_num_tokens),
                                                          dtype=np.int32)
        # NOTE(woosuk): These tensors are "stateless", i.e., they are literally
        # a faster version of creating a new tensor every time. Thus, we should
        # not make any assumptions about the values in these tensors.
        self.input_ids_cpu = torch.zeros(self.max_num_tokens,
                                         dtype=torch.int32,
                                         device="cpu",
                                         pin_memory=True)
        self.positions_cpu = torch.zeros(self.max_num_tokens,
                                         dtype=torch.int64,
                                         device="cpu",
                                         pin_memory=True)
        self.positions_np = self.positions_cpu.numpy()

        self.slot_mapping_cpu = torch.zeros(self.max_num_tokens,
                                            dtype=torch.int32,
                                            device="cpu",
                                            pin_memory=True)
        self.slot_mapping_np = self.slot_mapping_cpu.numpy()
        self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
                                               dtype=torch.int32,
                                               device="cpu",
                                               pin_memory=True)
        self.query_start_loc_np = self.query_start_loc_cpu.numpy()
        self.seq_lens_cpu = torch.zeros(self.max_num_reqs,
                                        dtype=torch.int32,
                                        device="cpu",
                                        pin_memory=True)
        self.seq_lens_np = self.seq_lens_cpu.numpy()

        self.use_aclgraph = self._use_aclgraph()
        self.aclgraph_batch_sizes = list(
            reversed(self.compilation_config.cudagraph_capture_sizes))

        self.uniform_decode_query_len = 1 if not self.speculative_config else \
            1 + self.speculative_config.num_speculative_tokens
        # aclgraph dispatcher for runtime aclgraph dispatching.
        self.aclgraph_dispatcher = CudagraphDispatcher(self.vllm_config)
        # Cached outputs.
        self._draft_token_ids: Optional[Union[list[list[int]],
                                              torch.Tensor]] = None

        # NOTE: we need to use `in_profile_run` to determine whether `enable_force_load_balance` is True
        self.in_profile_run = False

        self._init_mc2_tokens_capacity()
        if is_moe_model(vllm_config):
            self.reserved_mc2_mask = torch.zeros(
                self.mc2_tokens_capacity,
                dtype=torch.bool,
                device=self.device,
            )
        else:
            self.reserved_mc2_mask = None
        self.dynamic_eplb = self.ascend_config.dynamic_eplb or self.ascend_config.expert_map_record_path
        if self.dynamic_eplb:
            EPLBParamUtils.check_dynamic_eplb(self.ascend_config.dynamic_eplb)
            EPLBParamUtils.check_expert_map_record_path(
                self.ascend_config.expert_map_record_path)
            self.is_eplb_warmuped = False
            self.policy_type = self.ascend_config.eplb_policy_type
            self.eplb_loader = D2DExpertWeightLoader()
            self.manager = Manager()
            self.shared_dict = self.manager.dict({
                "expert_map": None,
                "moe_load": None,
                "expert_maps": None
            })
            self.eplb_process = EplbProcess(shared_dict=self.shared_dict,
                                            policy_type=self.policy_type,
                                            enable_d2d=True)
            self.process = self.eplb_process._launch_process()
            ascend_config = get_ascend_config()
            self.eplb_updator = EplbUpdator(ascend_config, self.eplb_loader,
                                            self.eplb_process, self.process)

        self.use_async_scheduling = self.scheduler_config.async_scheduling
        self.async_output_copy_stream = torch.npu.Stream() if \
            self.use_async_scheduling else None
        # Input Batch
        # NOTE(Chen): Ideally, we should initialize the input batch inside
        # `initialize_kv_cache` based on the kv cache config. However, as in
        # https://github.com/vllm-project/vllm/pull/18298, due to some unknown
        # reasons, we have to initialize the input batch before `load_model`,
        # quantization + weight offloading will fail otherwise. As a temporary
        # solution, we initialize the input batch here, and re-initialize it
        # in `initialize_kv_cache` if the block_sizes here is different from
        # the block_sizes in the kv cache config.
        self.input_batch = InputBatch(
            max_num_reqs=self.max_num_reqs,
            max_model_len=self.model_config.max_model_len,
            max_num_batched_tokens=self.max_num_tokens,
            device=self.device,
            pin_memory=self.pin_memory,
            vocab_size=self.model_config.get_vocab_size(),
            block_sizes=[self.block_size],
            is_spec_decode=bool(self.vllm_config.speculative_config),
            logitsprocs=build_logitsprocs(
                self.vllm_config, self.device, self.pin_memory,
                self.is_pooling_model,
                self.vllm_config.model_config.logits_processors),
            is_pooling_model=self.is_pooling_model,
            kernel_block_sizes=[[self.vllm_config.cache_config.block_size]],
        )
        self.num_accepted_tokens = self._make_buffer(self.max_num_reqs,
                                                     dtype=torch.int64)
        self.num_draft_tokens = self._make_buffer(self.max_num_reqs,
                                                  dtype=torch.int32)

    def _may_pad_kv_consumer_num_seq(self):
        # For Full Graph + MTP in a PD (Prefill/Decode) disaggregation scenario,
        # we may want to pad self.max_num_seqs in kv_consumer nodes to avoid
        # exceeding a sequence length limit (16 tokens) in npu_fused_infer_attention_score operation
        pass

    def _init_mc2_tokens_capacity(self):
        # NOTE: To be clear, we need to make sure that during graph capture, the number of
        # tokens is less than or equal to mc2_tokens_capacity. According to _set_cudagraph_sizes,
        # the max number of tokens in graph is min(max_num_seqs * uniform_decode_query_len, 512).
        if self.compilation_config.cudagraph_capture_sizes:
            max_num_tokens = self.compilation_config.cudagraph_capture_sizes[0]
        else:
            # NOTE: To save memory, we cap the max number of tokens to 512.
            max_num_tokens = min(
                self.max_num_reqs * self.uniform_decode_query_len, 512)
        tp_size = self.parallel_config.tensor_parallel_size
        # Use integer arithmetic for ceiling division.
        num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
        self.mc2_tokens_capacity = num_tokens_per_tp_rank * tp_size
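        # Worked example (illustrative): with max_num_tokens = 500 and
        # tp_size = 8, num_tokens_per_tp_rank = ceil(500 / 8) = 63, so
        # mc2_tokens_capacity = 63 * 8 = 504, i.e. the smallest multiple of
        # tp_size that is >= max_num_tokens.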

    def _make_buffer(self,
                     *size: Union[int, torch.SymInt],
                     dtype: torch.dtype,
                     numpy: bool = True) -> CpuGpuBuffer:
        # Bfloat16 torch tensors cannot be directly cast to a numpy array, so
        # if a bfloat16 buffer is needed without a corresponding numpy array,
        # don't bother instantiating the numpy array.
        return CpuGpuBuffer(*size,
                            dtype=dtype,
                            device=self.device,
                            pin_memory=self.pin_memory,
                            with_numpy=numpy)

def _update_states_after_model_execute(
|
|
self, output_token_ids: torch.Tensor) -> None:
|
|
"""Update the cached states after model execution.
|
|
|
|
This is used for MTP/EAGLE for hybrid models, as in linear attention,
|
|
only the last token's state is kept. In MTP/EAGLE, for draft tokens
|
|
        the states are kept until we decide how many tokens are accepted for
|
|
each sequence, and a shifting is done during the next iteration
|
|
based on the number of accepted tokens.
|
|
"""
|
|
if not self.model_config.is_hybrid or not self.speculative_config:
|
|
return
|
|
|
|
# Find the number of accepted tokens for each sequence.
|
|
num_accepted_tokens = (torch.cat(
|
|
[
|
|
output_token_ids,
|
|
torch.full((output_token_ids.size(0), 1),
|
|
-1,
|
|
device=output_token_ids.device),
|
|
],
|
|
dim=1) == -1).int().argmax(-1).cpu().numpy()
|
|
for i, num_tokens in enumerate(num_accepted_tokens):
|
|
self.input_batch.num_accepted_tokens_cpu[i] = num_tokens
|
|
|
|
def _use_aclgraph(self) -> bool:
|
|
return self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE and self.compilation_config.level == CompilationLevel.PIECEWISE and not self.model_config.enforce_eager
|
|
|
|
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
|
|
# Remove finished requests from the cached states.
|
|
for req_id in scheduler_output.finished_req_ids:
|
|
self.requests.pop(req_id, None)
|
|
|
|
# Remove the finished requests from the persistent batch.
|
|
# NOTE(woosuk): There could be an edge case where finished_req_ids and
|
|
# scheduled_req_ids overlap. This happens when a request is aborted and
|
|
# then resubmitted with the same ID. In this case, we treat them as two
|
|
# distinct requests - clearing the cached states for the first request
|
|
# and handling the second as a new request.
|
|
for req_id in scheduler_output.finished_req_ids:
|
|
self.input_batch.remove_request(req_id)
|
|
for mm_hash in scheduler_output.free_encoder_mm_hashes:
|
|
self.encoder_cache.pop(mm_hash, None)
|
|
# Remove the unscheduled requests from the persistent batch.
|
|
# NOTE(woosuk): The unscheduled requests are either preempted requests
|
|
# or running requests that are not scheduled in this step. We remove
|
|
# them from the persistent batch but keep their cached states since
|
|
# they will be scheduled again sometime in the future.
|
|
scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys()
|
|
cached_req_ids = self.input_batch.req_id_to_index.keys()
|
|
unscheduled_req_ids = cached_req_ids - scheduled_req_ids
|
|
# NOTE(woosuk): The persistent batch optimization assumes that
|
|
# consecutive batches contain mostly the same requests. If batches
|
|
# have low request overlap (e.g., alternating between two distinct
|
|
# sets of requests), this optimization becomes very inefficient.
|
|
for req_id in unscheduled_req_ids:
|
|
self.input_batch.remove_request(req_id)
|
|
|
|
req_ids_to_add: list[str] = []
|
|
# Add new requests to the cached states.
|
|
for new_req_data in scheduler_output.scheduled_new_reqs:
|
|
req_id = new_req_data.req_id
|
|
sampling_params = new_req_data.sampling_params
|
|
pooling_params = new_req_data.pooling_params
|
|
|
|
if sampling_params and \
|
|
sampling_params.sampling_type == SamplingType.RANDOM_SEED:
|
|
generator = torch.Generator(device=self.device)
|
|
generator.manual_seed(sampling_params.seed)
|
|
else:
|
|
generator = None
|
|
|
|
if pooling_params:
|
|
assert (task := pooling_params.task) is not None, (
|
|
"You did not set `task` in the API")
|
|
model = cast(VllmModelForPooling, self.get_model())
|
|
to_update = model.pooler.get_pooling_updates(task)
|
|
to_update.apply(pooling_params)
|
|
|
|
backward_kwargs = {}
|
|
backward_kwargs["mm_features"] = new_req_data.mm_features
|
|
|
|
self.requests[req_id] = CachedRequestState(
|
|
req_id=req_id,
|
|
prompt_token_ids=new_req_data.prompt_token_ids,
|
|
sampling_params=sampling_params,
|
|
pooling_params=pooling_params,
|
|
generator=generator,
|
|
block_ids=new_req_data.block_ids,
|
|
num_computed_tokens=new_req_data.num_computed_tokens,
|
|
output_token_ids=[],
|
|
lora_request=new_req_data.lora_request,
|
|
**backward_kwargs,
|
|
)
|
|
|
|
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
|
|
if self.uses_mrope:
|
|
self._init_mrope_positions(self.requests[req_id])
|
|
|
|
req_ids_to_add.append(req_id)
|
|
|
|
# Update the states of the running/resumed requests.
|
|
is_last_rank = get_pp_group().is_last_rank
|
|
req_data = scheduler_output.scheduled_cached_reqs
|
|
for i, req_id in enumerate(req_data.req_ids):
|
|
req_state = self.requests[req_id]
|
|
num_computed_tokens = req_data.num_computed_tokens[i]
|
|
new_block_ids = req_data.new_block_ids[i]
|
|
resumed_from_preemption = req_data.resumed_from_preemption[i]
|
|
|
|
# Update the cached states.
|
|
req_state.num_computed_tokens = num_computed_tokens
|
|
|
|
if not is_last_rank:
|
|
# When using PP, the scheduler sends the sampled tokens back,
|
|
# because there's no direct communication between the first-
|
|
# stage worker and the last-stage worker.
|
|
new_token_ids = req_data.new_token_ids[i]
|
|
# Add the sampled token(s) from the previous step (if any).
|
|
# This doesn't include "unverified" tokens like spec tokens.
|
|
num_new_tokens = (num_computed_tokens + len(new_token_ids) -
|
|
req_state.num_tokens)
|
|
if num_new_tokens == 1:
|
|
# Avoid slicing list in most common case.
|
|
req_state.output_token_ids.append(new_token_ids[-1])
|
|
elif num_new_tokens > 0:
|
|
req_state.output_token_ids.extend(
|
|
new_token_ids[-num_new_tokens:])
|
|
|
|
# Update the block IDs.
|
|
if not resumed_from_preemption:
|
|
if new_block_ids is not None:
|
|
# Append the new blocks to the existing block IDs.
|
|
for block_ids, new_ids in zip(req_state.block_ids,
|
|
new_block_ids):
|
|
block_ids.extend(new_ids)
|
|
else:
|
|
assert new_block_ids is not None
|
|
# The request is resumed from preemption.
|
|
# Replace the existing block IDs with the new ones.
|
|
req_state.block_ids = new_block_ids
|
|
|
|
req_index = self.input_batch.req_id_to_index.get(req_id)
|
|
if req_index is None:
|
|
# The request is not in the persistent batch.
|
|
# The request was either preempted and resumed later, or was not
|
|
# scheduled in the previous step and needs to be added again.
|
|
req_ids_to_add.append(req_id)
|
|
continue
|
|
|
|
# Update the persistent batch.
|
|
self.input_batch.num_computed_tokens_cpu[req_index] = (
|
|
num_computed_tokens)
|
|
if new_block_ids is not None:
|
|
self.input_batch.block_table.append_row(
|
|
new_block_ids, req_index)
|
|
|
|
# For the last rank, we don't need to update the token_ids_cpu
|
|
# because the sampled tokens are already cached.
|
|
if not is_last_rank:
|
|
# Add new_token_ids to token_ids_cpu.
|
|
start_token_index = num_computed_tokens
|
|
end_token_index = num_computed_tokens + len(new_token_ids)
|
|
self.input_batch.token_ids_cpu[
|
|
req_index,
|
|
start_token_index:end_token_index] = new_token_ids
|
|
self.input_batch.num_tokens_no_spec[
|
|
req_index] = end_token_index
|
|
self.input_batch.num_tokens[req_index] = end_token_index
|
|
|
|
# Add spec_token_ids to token_ids_cpu.
|
|
spec_token_ids = (
|
|
scheduler_output.scheduled_spec_decode_tokens.get(req_id, ()))
|
|
if spec_token_ids:
|
|
num_spec_tokens = len(spec_token_ids)
|
|
start_index = self.input_batch.num_tokens_no_spec[req_index]
|
|
end_token_index = start_index + num_spec_tokens
|
|
self.input_batch.token_ids_cpu[
|
|
req_index, start_index:end_token_index] = spec_token_ids
|
|
# NOTE(woosuk): `num_tokens` here may include spec tokens.
|
|
self.input_batch.num_tokens[req_index] += num_spec_tokens
|
|
|
|
# Add the new or resumed requests to the persistent batch.
|
|
# The smaller empty indices are filled first.
|
|
for req_id in req_ids_to_add:
|
|
req_state = self.requests[req_id]
|
|
self.input_batch.add_request(req_state)
|
|
|
|
# Condense the batched states if there are gaps left by removed requests
|
|
self.input_batch.condense()
|
|
# Allow attention backend to reorder the batch, potentially
|
|
self._may_reorder_batch(scheduler_output)
|
|
# Refresh batch metadata with any pending updates.
|
|
self.input_batch.refresh_metadata()
|
|
|
|
def _init_mrope_positions(self, req_state: CachedRequestState):
|
|
image_grid_thw = []
|
|
video_grid_thw = []
|
|
second_per_grid_ts = []
|
|
audio_feature_lengths = []
|
|
use_audio_in_video = False
|
|
assert req_state.mm_features is not None
|
|
for mm_feature in req_state.mm_features:
|
|
mm_item = mm_feature.data
|
|
if mm_item is None:
|
|
continue
|
|
mm_input = mm_item.get_data()
|
|
if (t := mm_input.get("image_grid_thw")) is not None:
|
|
image_grid_thw.append(t.tolist())
|
|
if (t := mm_input.get("video_grid_thw")) is not None:
|
|
video_grid_thw.append(t.tolist())
|
|
if (t := mm_input.get("second_per_grid_ts")) is not None:
|
|
second_per_grid_ts.append(t)
|
|
if (t := mm_input.get("audio_feature_lengths")) is not None:
|
|
audio_feature_lengths.append(t)
|
|
if mm_input.get("use_audio_in_video") is True:
|
|
use_audio_in_video = True
|
|
|
|
req_state.mrope_positions, req_state.mrope_position_delta = \
|
|
MRotaryEmbedding.get_input_positions_tensor(
|
|
req_state.prompt_token_ids,
|
|
hf_config=self.model_config.hf_config,
|
|
image_grid_thw=image_grid_thw,
|
|
video_grid_thw=video_grid_thw,
|
|
second_per_grid_ts=second_per_grid_ts,
|
|
audio_feature_lengths=audio_feature_lengths,
|
|
use_audio_in_video=use_audio_in_video,
|
|
)
|
|
|
|
def _sync_metadata_across_dp(
|
|
self, num_tokens: int, with_prefill: bool, enable_dbo: bool
|
|
) -> tuple[int, Optional[torch.Tensor], bool, bool]:
|
|
# TODO: In vLLM, the only thing that needs to be synced is num_tokens, but in
|
|
# our case, we still need to sync the other two flags as well. So we need to
|
|
        # include them in the all_reduce operation, and moreover, we CANNOT skip it
|
|
# even if we are running in eager mode, which harms performance.
|
|
# FIXME: Restore the `or self.vllm_config.model_config.enforce_eager` here
|
|
# immediately once the other two flags are no longer needed.
|
|
if self.dp_size == 1:
|
|
return num_tokens, None, with_prefill, enable_dbo
|
|
|
|
# Sync num_tokens, with_prefill, enable_dbo across dp ranks
|
|
num_tokens_tensor = torch.tensor([
|
|
num_tokens if i == self.dp_rank else 0 for i in range(self.dp_size)
|
|
],
|
|
dtype=torch.int32,
|
|
device="npu")
|
|
|
|
flags_tensor = torch.tensor(
|
|
[int(with_prefill), int(not enable_dbo)],
|
|
dtype=torch.int32,
|
|
device="npu")
|
|
|
|
packed_tensor = torch.cat([num_tokens_tensor, flags_tensor])
|
|
|
|
dist.all_reduce(packed_tensor, group=get_dp_group().device_group)
|
|
|
|
# Unpack the results
|
|
num_tokens_across_dp = packed_tensor[:-2]
|
|
synced_flags = packed_tensor[-2:]
|
|
|
|
max_tokens_across_dp = torch.max(num_tokens_across_dp).item()
|
|
global_with_prefill = bool(synced_flags[0])
|
|
global_enable_dbo = not bool(synced_flags[1])
|
|
|
|
# Create a tensor for num_tokens_after_padding
|
|
num_tokens_after_padding = torch.tensor([max_tokens_across_dp] *
|
|
self.dp_size,
|
|
device="cpu",
|
|
dtype=torch.int32)
|
|
|
|
return max_tokens_across_dp, num_tokens_after_padding, global_with_prefill, global_enable_dbo
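        # Worked example (illustrative), dp_size = 2: rank 0 contributes
        # [n0, 0] and rank 1 contributes [0, n1] for the token counts, so the
        # all_reduce(SUM) yields [n0, n1]. The two packed flags are summed the
        # same way: with_prefill becomes "any rank has prefill", and because
        # the second flag is `not enable_dbo`, dbo stays enabled only if every
        # rank enabled it.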
|
|
|
|
def _check_dbo_is_valid(self, query_lens: torch.Tensor,
|
|
attn_state: AscendAttentionState,
|
|
num_tokens: int) -> bool:
|
|
# do the checks for dp + dbo
|
|
if attn_state in [
|
|
AscendAttentionState.DecodeOnly,
|
|
AscendAttentionState.SpecDecoding
|
|
]:
|
|
return False
|
|
# considering the case that one dp rank may enable dbo while others may not
|
|
if not self.vllm_config.model_config.use_mla or not envs_ascend.VLLM_ASCEND_ENABLE_DBO:
|
|
return False
|
|
# TODO: remove it if token-level microbatch is enabled
|
|
[token_index,
|
|
seq_index] = compute_split_seq_index(query_lens, attn_state,
|
|
num_tokens)
|
|
if token_index == 0 or seq_index == 0 or seq_index == len(
|
|
query_lens) or num_tokens < 256:
|
|
return False
|
|
return True
|
|
|
|
def get_model(self) -> nn.Module:
|
|
# get raw model out of the aclgraph wrapper.
|
|
if isinstance(self.model, ACLGraphWrapper):
|
|
return self.model.unwrap()
|
|
return self.model
|
|
|
|
def get_supported_generation_tasks(self) -> "list[GenerationTask]":
|
|
model = self.get_model()
|
|
supported_tasks = list[GenerationTask]()
|
|
|
|
if is_text_generation_model(model):
|
|
supported_tasks.append("generate")
|
|
|
|
if supports_transcription(model):
|
|
if model.supports_transcription_only:
|
|
return ["transcription"]
|
|
|
|
supported_tasks.append("transcription")
|
|
|
|
return supported_tasks
|
|
|
|
def get_supported_tasks(self) -> "tuple[SupportedTask, ...]":
|
|
tasks = list[SupportedTask]()
|
|
|
|
if self.model_config.runner_type == "generate":
|
|
tasks.extend(self.get_supported_generation_tasks())
|
|
if self.model_config.runner_type == "pooling":
|
|
tasks.extend(self.get_supported_pooling_tasks())
|
|
|
|
return tuple(tasks)
|
|
|
|
def _make_attention_mask(self, seq_lens, position,
|
|
attn_state) -> torch.Tensor:
|
|
# Pooling situation.
|
|
if self.model_config.runner_type == "pooling" and self.model_config.pooler_config.pooling_type == "CLS":
|
|
return self.attn_mask_builder.get_pooling_mask(self.device)
|
|
# Chunk Prefill situation.
|
|
elif attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla and not self.use_sparse:
|
|
if torch.version.cann.startswith("8.3"):
|
|
return self.attn_mask_builder.get_splitfuse_attn_mask()
|
|
else:
|
|
return self.attn_mask_builder.get_splitfuse_attn_mask(
|
|
seq_lens, position, self.dtype, self.device)
|
|
|
|
# Prefill without cache situation.
|
|
elif attn_state == AscendAttentionState.PrefillNoCache:
|
|
max_seq_len = max(seq_lens.max().item(), 0)
|
|
return self.attn_mask_builder.get_attn_mask(
|
|
max_seq_len, self.dtype, self.device)
|
|
# Prefill with cache hit.
|
|
elif attn_state == AscendAttentionState.PrefillCacheHit:
|
|
return self.attn_mask_builder.get_attn_mask(
|
|
128, self.dtype, self.device)
|
|
# Decode-only situation.
|
|
else:
|
|
return None
|
|
|
|
def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
|
|
mrope_pos_ptr = 0
|
|
for index, req_id in enumerate(self.input_batch.req_ids):
|
|
req = self.requests[req_id]
|
|
assert req.mrope_positions is not None
|
|
|
|
num_computed_tokens = \
|
|
self.input_batch.num_computed_tokens_cpu[index]
|
|
num_scheduled_tokens = \
|
|
scheduler_output.num_scheduled_tokens[req_id]
|
|
num_prompt_tokens = len(req.prompt_token_ids)
|
|
|
|
if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
|
|
prompt_part_len = max(0,
|
|
num_prompt_tokens - num_computed_tokens)
|
|
completion_part_len = max(
|
|
0, num_scheduled_tokens - prompt_part_len)
|
|
else:
|
|
prompt_part_len = num_scheduled_tokens
|
|
completion_part_len = 0
|
|
|
|
assert num_scheduled_tokens == prompt_part_len + completion_part_len
|
|
|
|
if prompt_part_len > 0:
|
|
# prompt's mrope_positions are pre-computed
|
|
dst_start = mrope_pos_ptr
|
|
dst_end = mrope_pos_ptr + prompt_part_len
|
|
src_start = num_computed_tokens
|
|
src_end = num_computed_tokens + prompt_part_len
|
|
|
|
self.mrope_positions_cpu[:, dst_start:dst_end] = \
|
|
req.mrope_positions[:,src_start:src_end]
|
|
|
|
mrope_pos_ptr += prompt_part_len
|
|
|
|
if completion_part_len > 0:
|
|
# compute completion's mrope_positions on-the-fly
|
|
dst_start = mrope_pos_ptr
|
|
dst_end = mrope_pos_ptr + completion_part_len
|
|
MRotaryEmbedding.get_next_input_positions_tensor(
|
|
out=self.mrope_positions_np,
|
|
out_offset=dst_start,
|
|
mrope_position_delta=req.mrope_position_delta,
|
|
context_len=num_computed_tokens + prompt_part_len,
|
|
num_new_tokens=completion_part_len,
|
|
)
|
|
|
|
mrope_pos_ptr += completion_part_len
|
|
|
|
def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
|
|
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
|
|
if not scheduled_encoder_inputs:
|
|
return
|
|
|
|
# Batch the multi-modal inputs.
|
|
mm_kwargs, mm_hashes_pos = self._batch_mm_kwargs_from_scheduler(
|
|
scheduler_output)
|
|
encoder_outputs = []
|
|
|
|
for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
|
|
mm_kwargs,
|
|
device=self.device,
|
|
pin_memory=True,
|
|
):
|
|
# Run the encoder.
|
|
# `curr_group_outputs` is either of the following:
|
|
# 1. A tensor of shape (num_items, feature_size, hidden_size)
|
|
# in case feature_size is fixed across all multimodal items.
|
|
# 2. A list or tuple (length: num_items) of tensors, each of shape
|
|
# (feature_size, hidden_size) in case the feature size is dynamic
|
|
# depending on the input multimodal items.
|
|
curr_group_outputs = self.model.get_multimodal_embeddings(
|
|
**mm_kwargs_group)
|
|
|
|
sanity_check_mm_encoder_outputs(
|
|
curr_group_outputs,
|
|
expected_num_items=num_items,
|
|
)
|
|
|
|
for output in curr_group_outputs:
|
|
encoder_outputs.append(output)
|
|
|
|
for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs):
|
|
self.encoder_cache[mm_hash] = scatter_mm_placeholders(
|
|
output,
|
|
is_embed=pos_info.is_embed,
|
|
)
|
|
|
|
def _batch_mm_kwargs_from_scheduler(
|
|
self,
|
|
scheduler_output: "SchedulerOutput",
|
|
) -> tuple[list[MultiModalKwargsItem], list[tuple[str, PlaceholderRange]]]:
|
|
"""Batch multimodal kwargs from scheduled encoder inputs.
|
|
|
|
Args:
|
|
scheduler_output: The scheduler output containing scheduled encoder
|
|
inputs.
|
|
|
|
Returns:
|
|
            A tuple of (mm_kwargs, mm_hashes_pos) where:
|
|
- mm_kwargs: List of multimodal kwargs items to be batched
|
|
- mm_hashes_pos: List of (mm_hash, position_info) tuples
|
|
"""
|
|
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
|
|
if not scheduled_encoder_inputs:
|
|
return [], []
|
|
# Batch the multi-modal inputs.
|
|
mm_kwargs = list[MultiModalKwargsItem]()
|
|
# list of tuple (mm_hash, position_info)
|
|
mm_hashes_pos = list[tuple[str, PlaceholderRange]]()
|
|
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
|
|
req_state = self.requests[req_id]
|
|
assert req_state.mm_features is not None
|
|
for mm_input_id in encoder_input_ids:
|
|
mm_feature = req_state.mm_features[mm_input_id]
|
|
mm_hash = mm_feature.identifier
|
|
mm_kwargs.append(mm_feature.data)
|
|
mm_hashes_pos.append((mm_hash, mm_feature.mm_position))
|
|
|
|
return mm_kwargs, mm_hashes_pos
|
|
|
|
def _gather_mm_embeddings(
|
|
self,
|
|
scheduler_output: "SchedulerOutput",
|
|
) -> list[torch.Tensor]:
|
|
|
|
def _iter_mm_features(req_state: CachedRequestState):
|
|
assert req_state.mm_features is not None
|
|
for mm_feature in req_state.mm_features:
|
|
pos_info = mm_feature.mm_position
|
|
yield mm_feature.identifier, pos_info, getattr(
|
|
pos_info, "is_embed", None)
|
|
|
|
mm_embeds: list[torch.Tensor] = []
|
|
|
|
for req_id in self.input_batch.req_ids:
|
|
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
|
|
req_id]
|
|
req_state = self.requests[req_id]
|
|
num_computed_tokens = req_state.num_computed_tokens
|
|
|
|
for mm_hash, pos_info, is_embed in _iter_mm_features(req_state):
|
|
start_pos = pos_info.offset
|
|
num_encoder_tokens = pos_info.length
|
|
|
|
if start_pos >= num_computed_tokens + num_scheduled_tokens:
|
|
break
|
|
if start_pos + num_encoder_tokens <= num_computed_tokens:
|
|
continue
|
|
|
|
start_idx = max(num_computed_tokens - start_pos, 0)
|
|
end_idx = min(
|
|
num_computed_tokens - start_pos + num_scheduled_tokens,
|
|
num_encoder_tokens,
|
|
)
|
|
assert start_idx < end_idx
|
|
|
|
encoder_output = self.encoder_cache.get(mm_hash, None)
|
|
assert encoder_output is not None, \
|
|
f"Encoder cache miss for {mm_hash}."
|
|
|
|
if is_embed is not None:
|
|
is_embed = is_embed[start_idx:end_idx]
|
|
|
|
mm_embeds_item = gather_mm_placeholders(
|
|
encoder_output[start_idx:end_idx],
|
|
is_embed=is_embed,
|
|
)
|
|
mm_embeds.append(mm_embeds_item)
|
|
return mm_embeds
|
|
|
|
def _get_cumsum_and_arange(
|
|
self,
|
|
num_tokens: np.ndarray,
|
|
cumsum_dtype: Optional[np.dtype] = None,
|
|
) -> tuple[np.ndarray, np.ndarray]:
|
|
"""Get the cumulative sum and batched arange of the given array.
|
|
# E.g., [2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2])
|
|
# Equivalent to but faster than:
|
|
# np.concatenate([np.arange(n) for n in num_tokens])
|
|
"""
|
|
# Step 1. [2, 5, 3] -> [2, 7, 10]
|
|
cu_num_tokens = np.cumsum(num_tokens, dtype=cumsum_dtype)
|
|
total_num_tokens = cu_num_tokens[-1]
|
|
# Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7]
|
|
cumsums_offsets = np.repeat(cu_num_tokens - num_tokens, num_tokens)
|
|
# Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
|
|
arange = self.arange_np[:total_num_tokens] - cumsums_offsets
|
|
|
|
return cu_num_tokens, arange
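        # Quick check against the docstring example (illustrative):
        #   num_tokens = np.array([2, 5, 3])
        #   cu_num_tokens -> [2, 7, 10]
        #   arange        -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]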
|
|
|
|
def _prepare_input_ids(self, total_num_scheduled_tokens: int,
|
|
cu_num_tokens: np.ndarray) -> None:
|
|
"""Prepare the input IDs for the current batch.
|
|
|
|
Carefully handles the `prev_sampled_token_ids` which can be cached
|
|
from the previous engine iteration, in which case those tokens on the
|
|
        NPU need to be copied into the corresponding slots in input_ids."""
|
|
|
|
if self.input_batch.prev_sampled_token_ids is None:
|
|
# Normal scheduling case
|
|
self.input_ids[:total_num_scheduled_tokens].copy_(
|
|
self.input_ids_cpu[:total_num_scheduled_tokens],
|
|
non_blocking=True)
|
|
return
|
|
|
|
# Async scheduling case, where some decode requests from the previous
|
|
# iteration won't have entries in input_ids_cpu and need to be copied
|
|
# on the NPU from prev_sampled_token_ids.
|
|
prev_req_id_to_index = self.input_batch.prev_req_id_to_index
|
|
assert prev_req_id_to_index is not None
|
|
flattened_indices = []
|
|
prev_common_req_indices = []
|
|
indices_match = True
|
|
max_flattened_index = -1
|
|
for req_id, cur_index in self.input_batch.req_id_to_index.items():
|
|
if (prev_index := prev_req_id_to_index.get(req_id)) is not None:
|
|
prev_common_req_indices.append(prev_index)
|
|
# We need to compute the flattened input_ids index of the
|
|
# last token in each common request.
|
|
flattened_index = cu_num_tokens[cur_index].item() - 1
|
|
flattened_indices.append(flattened_index)
|
|
indices_match &= (prev_index == flattened_index)
|
|
max_flattened_index = max(max_flattened_index, flattened_index)
|
|
num_commmon_tokens = len(flattened_indices)
|
|
if num_commmon_tokens < total_num_scheduled_tokens:
|
|
# If not all requests are decodes from the last iteration,
|
|
# We need to copy the input_ids_cpu to the NPU first.
|
|
self.input_ids[:total_num_scheduled_tokens].copy_(
|
|
self.input_ids_cpu[:total_num_scheduled_tokens],
|
|
non_blocking=True)
|
|
if num_commmon_tokens == 0:
|
|
# No requests in common with the previous iteration
|
|
# So input_ids_cpu will have all the input ids.
|
|
return
|
|
if indices_match and max_flattened_index == (num_commmon_tokens - 1):
|
|
# Common-case optimization: the batch is unchanged
|
|
# and no reordering happened.
|
|
# The indices are both the same permutation of 0..N-1 so
|
|
# we can copy directly using a single slice.
|
|
self.input_ids[:num_commmon_tokens].copy_(
|
|
self.input_batch.prev_sampled_token_ids[:num_commmon_tokens,
|
|
0],
|
|
non_blocking=True)
|
|
return
|
|
# Upload the index tensors asynchronously
|
|
# so the scatter can be non-blocking.
|
|
input_ids_index_tensor = torch.tensor(flattened_indices,
|
|
dtype=torch.int64,
|
|
pin_memory=self.pin_memory).to(
|
|
self.device,
|
|
non_blocking=True)
|
|
prev_common_req_indices_tensor = torch.tensor(
|
|
prev_common_req_indices,
|
|
dtype=torch.int64,
|
|
pin_memory=self.pin_memory).to(self.device, non_blocking=True)
|
|
self.input_ids.scatter_(dim=0,
|
|
index=input_ids_index_tensor,
|
|
src=self.input_batch.prev_sampled_token_ids[
|
|
prev_common_req_indices_tensor, 0])
|
|
|
|
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
|
|
"""
|
|
Update the order of requests in the batch based on the attention
|
|
backend's needs. For example, some attention backends (namely MLA) may
|
|
want to separate requests based on if the attention computation will be
|
|
compute-bound or memory-bound.
|
|
|
|
Args:
|
|
scheduler_output: The scheduler output.
|
|
"""
|
|
        # Attention free models have zero kv_cache_groups, however models
|
|
# like Mamba are also attention free but use the kv_cache for
|
|
# keeping its internal state. This is why we check the number
|
|
# of kv_cache groups instead of solely checking
|
|
# for self.model_config.is_attention_free.
|
|
if len(self.kv_cache_config.kv_cache_groups) == 0:
|
|
return
|
|
|
|
if self.reorder_batch_threshold is not None:
|
|
reorder_batch_to_split_decodes_and_prefills(
|
|
self.input_batch,
|
|
scheduler_output,
|
|
decode_threshold=self.reorder_batch_threshold)
|
|
|
|
def _prepare_inputs(
|
|
self,
|
|
scheduler_output: "SchedulerOutput",
|
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
|
) -> tuple[dict[str, Any], torch.Tensor, np.ndarray, int, torch.Tensor,
|
|
int, torch.Tensor, SpecDecodeMetadata, Optional[torch.Tensor],
|
|
Optional[torch.Tensor], Optional[torch.Tensor], int]:
|
|
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
|
assert total_num_scheduled_tokens > 0
|
|
num_reqs = self.input_batch.num_reqs
|
|
assert num_reqs > 0
|
|
|
|
# OPTIMIZATION: Start copying the block table first.
|
|
# This way, we can overlap the copy with the following CPU operations.
|
|
self.input_batch.block_table.commit_block_table(num_reqs)
|
|
|
|
# Get the number of scheduled tokens for each request.
|
|
req_ids = self.input_batch.req_ids
|
|
tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
|
|
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
|
|
max_num_scheduled_tokens = num_scheduled_tokens.max()
|
|
num_valid_tokens = np.array([
|
|
num_tokens -
|
|
len(scheduler_output.scheduled_spec_decode_tokens.get(i, []))
|
|
for num_tokens, i in zip(tokens, req_ids)
|
|
],
|
|
dtype=np.int32)
|
|
|
|
if (self.use_aclgraph and total_num_scheduled_tokens
|
|
<= self.aclgraph_batch_sizes[-1]):
|
|
# Add padding to the batch size.
|
|
num_input_tokens = self.vllm_config.pad_for_cudagraph(
|
|
total_num_scheduled_tokens)
|
|
elif self.use_aclgraph and enable_sp(self.vllm_config):
|
|
# When using aclgraph, if total_num_scheduled_tokens exceeds the maximum graph size,
|
|
# the model will fall back to running its FX graph in eager mode.
|
|
# In this case, when sequence parallelism is enabled, we need to pad tokens to align
|
|
# with tp_size because pad_size cannot be captured by the FX graph
|
|
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
|
|
num_input_tokens = math.ceil(
|
|
total_num_scheduled_tokens / tp_size) * tp_size
|
|
else:
|
|
# Eager mode.
|
|
num_input_tokens = total_num_scheduled_tokens
|
|
|
|
# Get the attention state.
|
|
attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens,
|
|
num_valid_tokens)
|
|
self.attn_state = attn_state # type: ignore
|
|
|
|
# Determine if it's a splitfuse batch
|
|
with_prefill = attn_state not in [
|
|
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
|
|
]
|
|
|
|
self.query_lens = torch.from_numpy(num_scheduled_tokens)
|
|
enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(),
|
|
attn_state,
|
|
total_num_scheduled_tokens)
|
|
|
|
# Get info across DP ranks.
|
|
# NOTE: maybe_padded_num_tokens is only used when using TorchAir with DP,
|
|
# Otherwise, it's just max_tokens_across_dp_cpu
|
|
(maybe_padded_num_tokens, num_tokens_across_dp, with_prefill,
|
|
enable_dbo) = self._sync_metadata_across_dp(num_input_tokens,
|
|
with_prefill, enable_dbo)
|
|
|
|
# TODO: Now that num_input_tokens is basically identical with maybe_padded_num_tokens
|
|
# We should consider removing maybe_padded_num_tokens later
|
|
num_input_tokens = maybe_padded_num_tokens
|
|
|
|
# Hot-Swap lora model
|
|
if self.lora_config:
|
|
self.set_active_loras(self.input_batch, num_scheduled_tokens)
|
|
|
|
# Get request indices.
|
|
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
|
|
req_indices = np.repeat(self.arange_np[:num_reqs],
|
|
num_scheduled_tokens)
|
|
|
|
# cu_num_tokens: [2, 5, 3] -> [2, 7, 10]
|
|
# arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
|
|
cu_num_tokens, arange = self._get_cumsum_and_arange(
|
|
num_scheduled_tokens)
|
|
|
|
positions_np = self.positions_np[:total_num_scheduled_tokens]
|
|
np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
|
|
arange,
|
|
out=positions_np)
|
|
|
|
# Calculate M-RoPE positions.
|
|
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
|
|
if self.uses_mrope:
|
|
self._calc_mrope_positions(scheduler_output)
|
|
|
|
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
|
|
self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
|
|
self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
|
|
non_blocking=True)
|
|
|
|
# Get token indices.
|
|
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
|
|
# -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
|
|
# where M is the max_model_len.
|
|
token_indices = (positions_np +
|
|
req_indices * self.input_batch.token_ids_cpu.shape[1])
|
|
|
|
# Prepare input_ids.
|
|
# NOTE(woosuk): We use torch.index_select instead of np.take here
|
|
# because torch.index_select is much faster than np.take for large
|
|
# tensors.
|
|
torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
|
|
0,
|
|
torch.from_numpy(token_indices),
|
|
out=self.input_ids_cpu[:total_num_scheduled_tokens])
|
|
|
|
# Prepare some information for building Attention-Metadata
|
|
# Compute and commit slot mapping
|
|
self.input_batch.block_table.compute_slot_mapping(
|
|
req_indices, positions_np)
|
|
self.input_batch.block_table.commit_slot_mapping(
|
|
total_num_scheduled_tokens)
|
|
|
|
self.query_start_loc_np[0] = 0
|
|
self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
|
|
self.query_start_loc[:num_reqs + 1].copy_(
|
|
self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
|
|
|
|
self.seq_lens_np[:num_reqs] = (
|
|
self.input_batch.num_computed_tokens_cpu[:num_reqs] +
|
|
num_scheduled_tokens)
|
|
self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
|
|
non_blocking=True)
|
|
|
|
# Fill unused with -1. Needed for reshape_and_cache
|
|
self.query_start_loc[num_reqs + 1:].fill_(-1)
|
|
self.seq_lens[num_reqs:].fill_(0)
|
|
|
|
self.query_lens = torch.from_numpy(num_scheduled_tokens)
|
|
|
|
# Copy the tensors to the NPU.
|
|
self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens)
|
|
self.positions_cpu[total_num_scheduled_tokens:num_input_tokens].zero_()
|
|
self.positions[:num_input_tokens].copy_(
|
|
self.positions_cpu[:num_input_tokens], non_blocking=True)
|
|
|
|
# Make Attention metadata
|
|
positions_cpu = self.positions_cpu[:num_input_tokens]
|
|
positions = self.positions[:num_input_tokens]
|
|
seq_lens_cpu = self.seq_lens_cpu[:num_reqs]
|
|
attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens,
|
|
num_valid_tokens)
|
|
self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu,
|
|
position=positions_cpu,
|
|
attn_state=attn_state)
|
|
self.attn_state = attn_state # type: ignore
|
|
|
|
self.with_prefill = with_prefill
|
|
self.num_tokens_across_dp = num_tokens_across_dp
|
|
self._update_graph_pad_size(with_prefill, maybe_padded_num_tokens)
|
|
attn_metadata: dict[str, Any] = {}
|
|
|
|
# _prepare_inputs may reorder the batch, so we must gather
|
|
# multi-modal outputs after that to ensure the correct order
|
|
if self.is_multimodal_model:
|
|
# Run the multimodal encoder if any.
|
|
self._execute_mm_encoder(scheduler_output)
|
|
mm_embeds = self._gather_mm_embeddings(scheduler_output)
|
|
|
|
# NOTE(woosuk): To unify token ids and soft tokens (vision
|
|
# embeddings), we always use embeddings (rather than token ids)
|
|
# as input to the multimodal model, even when the input is text.
|
|
input_ids = self.input_ids[:total_num_scheduled_tokens]
|
|
if mm_embeds:
|
|
inputs_embeds = self.model.get_input_embeddings(
|
|
input_ids, mm_embeds)
|
|
else:
|
|
inputs_embeds = self.model.get_input_embeddings(input_ids)
|
|
# TODO(woosuk): Avoid the copy. Optimize.
|
|
self.inputs_embeds[:total_num_scheduled_tokens].copy_(
|
|
inputs_embeds)
|
|
inputs_embeds = self.inputs_embeds[:num_input_tokens]
|
|
input_ids = None
|
|
else:
|
|
# For text-only models, we use token ids as input.
|
|
# While it is possible to use embeddings as input just like the
|
|
# multimodal models, it is not desirable for performance since
|
|
# then the embedding layer is not included in the ACL graph.
|
|
input_ids = self.input_ids[:num_input_tokens]
|
|
inputs_embeds = None
|
|
positions = self.positions[:num_input_tokens]
|
|
input_ids, positions = self._update_input_ids_and_positions(
|
|
input_ids, positions, num_input_tokens, with_prefill,
|
|
maybe_padded_num_tokens)
|
|
|
|
if get_pp_group().is_first_rank:
|
|
intermediate_tensors = None
|
|
else:
|
|
assert intermediate_tensors is not None
|
|
assert self.intermediate_tensors is not None
|
|
for k, v in intermediate_tensors.items():
|
|
self.intermediate_tensors[k][:num_input_tokens].copy_(
|
|
v[:num_input_tokens], non_blocking=True)
|
|
intermediate_tensors = IntermediateTensors({
|
|
k: v[:num_input_tokens]
|
|
for k, v in self.intermediate_tensors.items()
|
|
})
|
|
|
|
use_spec_decode = len(
|
|
scheduler_output.scheduled_spec_decode_tokens) > 0
|
|
if not use_spec_decode:
|
|
# NOTE(woosuk): Due to chunked prefills, the batch may contain
|
|
# partial requests. While we should not sample any token
|
|
# from these partial requests, we do so for simplicity.
|
|
# We will ignore the sampled tokens from the partial requests.
|
|
# TODO: Support prompt logprobs.
|
|
spec_decode_metadata = None
|
|
logits_indices = torch.from_numpy(cu_num_tokens - 1).to(
|
|
self.device, non_blocking=True)
|
|
else:
|
|
# Get the number of draft tokens for each request.
|
|
# Iterate over the dictionary rather than all requests since not all
|
|
# requests have draft tokens.
|
|
num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
|
|
for req_id, draft_token_ids in (
|
|
scheduler_output.scheduled_spec_decode_tokens.items()):
|
|
req_idx = self.input_batch.req_id_to_index[req_id]
|
|
num_draft_tokens[req_idx] = len(draft_token_ids)
|
|
|
|
spec_decode_metadata = self._calc_spec_decode_metadata(
|
|
num_draft_tokens, cu_num_tokens)
|
|
logits_indices = spec_decode_metadata.logits_indices
|
|
self.num_draft_tokens.np[:num_reqs] = num_draft_tokens
|
|
self.num_draft_tokens.np[num_reqs:].fill(0)
|
|
self.num_draft_tokens.copy_to_gpu()
|
|
|
|
# Used in the below loop.
|
|
# query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1]
|
|
num_computed_tokens_cpu = (
|
|
self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
|
|
spec_decode_common_attn_metadata = None
|
|
if use_spec_decode and self.need_accepted_tokens:
|
|
self.num_accepted_tokens.np[:num_reqs] = (
|
|
self.input_batch.num_accepted_tokens_cpu[:num_reqs])
|
|
self.num_accepted_tokens.np[num_reqs:].fill(1)
|
|
self.num_accepted_tokens.copy_to_gpu()
|
|
|
|
# Prepare the attention metadata for each KV cache group and make layers
|
|
# in the same group share the same metadata.
|
|
for kv_cache_group_id, kv_cache_group_spec in enumerate(
|
|
self.kv_cache_config.kv_cache_groups):
|
|
if isinstance(kv_cache_group_spec.kv_cache_spec,
|
|
EncoderOnlyAttentionSpec):
|
|
# Encoder-only layers do not have KV cache, so we need to
|
|
# create a dummy block table and slot mapping for them.
|
|
blk_table_tensor = torch.zeros(
|
|
(num_reqs, 1),
|
|
dtype=torch.int32,
|
|
device=self.device,
|
|
)
|
|
slot_mapping = torch.zeros(
|
|
(total_num_scheduled_tokens, ),
|
|
dtype=torch.int64,
|
|
device=self.device,
|
|
)
|
|
else:
|
|
blk_table = self.input_batch.block_table[kv_cache_group_id]
|
|
blk_table_tensor = blk_table.get_device_tensor()
|
|
slot_mapping = blk_table.slot_mapping_cpu[:
|
|
total_num_scheduled_tokens]
|
|
self.slot_mapping[:total_num_scheduled_tokens].copy_(
|
|
slot_mapping[:total_num_scheduled_tokens],
|
|
non_blocking=True,
|
|
)
|
|
self.slot_mapping[total_num_scheduled_tokens:].fill_(0)
|
|
|
|
# Make AscendCommonAttentionMetadata
|
|
common_attn_metadata = AscendCommonAttentionMetadata(
|
|
query_start_loc=self.query_start_loc[:num_reqs + 1],
|
|
query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
|
|
seq_lens_cpu=self.seq_lens_cpu,
|
|
seq_lens=self.seq_lens_cpu[:num_reqs],
|
|
num_reqs=num_reqs,
|
|
num_actual_tokens=total_num_scheduled_tokens,
|
|
num_input_tokens=num_input_tokens,
|
|
actual_seq_lengths_q=self.actual_seq_lengths_q,
|
|
# TODO: change this to the right block table for linear attn
|
|
block_table_tensor=blk_table_tensor[:num_reqs],
|
|
slot_mapping=self.slot_mapping,
|
|
num_computed_tokens_cpu=num_computed_tokens_cpu,
|
|
positions=self.positions,
|
|
attn_mask=self.attn_mask,
|
|
spec_attn_mask=self.spec_attn_mask,
|
|
attn_state=self.attn_state,
|
|
enable_dbo_across_dp=enable_dbo,
|
|
is_only_prefill=bool(np.all(num_valid_tokens != 1)),
|
|
max_query_len=max_num_scheduled_tokens,
|
|
graph_pad_size=self.graph_pad_size,
|
|
decode_token_per_req=self.decode_token_per_req,
|
|
cos=self.cos,
|
|
sin=self.sin,
|
|
)
|
|
|
|
if self.speculative_config and \
|
|
spec_decode_common_attn_metadata is None:
|
|
spec_decode_common_attn_metadata = common_attn_metadata
|
|
|
|
for attn_group in self.attn_groups[kv_cache_group_id]:
|
|
common_prefix_len = 0
|
|
extra_attn_metadata_args = {}
|
|
builder = attn_group.get_metadata_builder()
|
|
if isinstance(builder, GDNAttentionMetadataBuilder
|
|
) or self.model_config.runner_type == "pooling":
|
|
if use_spec_decode:
|
|
extra_attn_metadata_args = dict(
|
|
num_accepted_tokens=self.num_accepted_tokens.
|
|
gpu[:num_reqs],
|
|
num_draft_tokens=self.num_draft_tokens.
|
|
gpu[:num_reqs],
|
|
)
|
|
attn_metadata_i = builder.build(
|
|
common_prefix_len=common_prefix_len,
|
|
common_attn_metadata=common_attn_metadata,
|
|
**extra_attn_metadata_args)
|
|
else:
|
|
attn_metadata_i = builder.build(
|
|
common_prefix_len=common_prefix_len,
|
|
common_attn_metadata=common_attn_metadata,
|
|
model=self.get_model(),
|
|
**extra_attn_metadata_args)
|
|
|
|
for layer_name in attn_group.layer_names:
|
|
attn_metadata[layer_name] = attn_metadata_i
|
|
|
|
if lmhead_tp_enable():
|
|
max_num_reqs_across_dp = maybe_padded_num_tokens if not with_prefill else self.max_num_reqs
|
|
logits_indices = nn.functional.pad(
|
|
logits_indices,
|
|
(0, max_num_reqs_across_dp - logits_indices.shape[0]))
|
|
|
|
return (attn_metadata, positions, num_scheduled_tokens,
|
|
num_input_tokens, num_tokens_across_dp,
|
|
maybe_padded_num_tokens, logits_indices, spec_decode_metadata,
|
|
input_ids, inputs_embeds, intermediate_tensors,
|
|
max_num_scheduled_tokens)
|
|
|
|
def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill,
|
|
maybe_padded_num_tokens,
|
|
input_ids, positions,
|
|
intermediate_tensors,
|
|
inputs_embeds):
|
|
assert self.model is not None
|
|
hidden_states = self.model(
|
|
input_ids=input_ids,
|
|
positions=positions,
|
|
intermediate_tensors=intermediate_tensors,
|
|
inputs_embeds=inputs_embeds,
|
|
)
|
|
|
|
forward_context = get_forward_context()
|
|
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
|
|
# TODO: maybe_padded_num_tokens will be removed, use num_input_tokens instead
|
|
if self.vllm_config.model_config.use_mla:
|
|
# FIXME: Try using `auto_dispatch_capture=True`
|
|
update_mla_attn_params(self.update_stream, forward_context,
|
|
maybe_padded_num_tokens,
|
|
self.speculative_config)
|
|
else:
|
|
update_attn_params(self.update_stream, forward_context,
|
|
maybe_padded_num_tokens)
|
|
|
|
if get_forward_context().sp_enabled:
|
|
hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
|
|
pad_size = get_forward_context().pad_size
|
|
if pad_size > 0:
|
|
hidden_states = hidden_states[:-pad_size, :]
|
|
return hidden_states
|
|
|
|
def _build_attn_state(self, num_reqs, num_scheduled_tokens,
|
|
num_valid_tokens):
|
|
ascend_config = get_ascend_config()
|
|
if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
|
|
attn_state = AscendAttentionState.PrefillNoCache
|
|
# We assume this is the decode stage: even when a prefill occurs, only one token per request misses the cache.
|
|
elif np.all(num_scheduled_tokens == 1):
|
|
attn_state = AscendAttentionState.DecodeOnly
|
|
if self.speculative_config and self.speculative_config.method == 'deepseek_mtp':
|
|
# SpecDecoding now supports seq_len=1 and seq_len=2
|
|
# In the prefill-decode disaggregation scenario, SpecDecoding needs to support seq_len=1
|
|
attn_state = AscendAttentionState.SpecDecoding
|
|
# Speculative decoding.
|
|
elif np.all(num_valid_tokens == 1):
|
|
if self.drafter and (self.drafter.name == SpecDcodeType.EAGLE
|
|
or self.drafter.name == SpecDcodeType.EAGLE3):
|
|
attn_state = AscendAttentionState.ChunkedPrefill
|
|
else:
|
|
attn_state = AscendAttentionState.SpecDecoding
|
|
# splitfuse
|
|
elif not ascend_config.ascend_scheduler_config.enabled or self.chunked_prefill_enabled:
|
|
attn_state = AscendAttentionState.ChunkedPrefill
|
|
else:
|
|
attn_state = AscendAttentionState.PrefillCacheHit
|
|
return attn_state
|
|
|
|
def _update_graph_pad_size(self, with_prefill, graph_pad_size):
|
|
self.graph_pad_size = -1
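# NOTE: -1 means no extra graph padding in this runner; subclasses that
# capture graphs may override this hook to pad the batch to a captured size.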
|
|
|
|
def _update_input_ids_and_positions(self, input_ids, positions,
|
|
num_input_tokens, with_prefill,
|
|
maybe_padded_num_tokens):
|
|
if self.uses_mrope:
|
|
positions = self.mrope_positions[:, :num_input_tokens]
|
|
return input_ids, positions
|
|
|
|
def _calc_spec_decode_metadata(
|
|
self,
|
|
num_draft_tokens: np.ndarray,
|
|
cu_num_scheduled_tokens: np.ndarray,
|
|
) -> SpecDecodeMetadata:
|
|
# Inputs:
|
|
# cu_num_scheduled_tokens: [ 4, 104, 107, 207, 209]
|
|
# num_draft_tokens: [ 3, 0, 2, 0, 1]
|
|
# Outputs:
|
|
# cu_num_draft_tokens: [ 3, 3, 5, 5, 6]
|
|
# logits_indices: [ 0, 1, 2, 3, 103, 104, 105, 106,
|
|
# 206, 207, 208]
|
|
# target_logits_indices: [ 0, 1, 2, 5, 6, 9]
|
|
# bonus_logits_indices: [ 3, 4, 7, 8, 10]
|
|
|
|
# Compute the logits indices.
|
|
# [4, 1, 3, 1, 2]
|
|
num_sampled_tokens = num_draft_tokens + 1
|
|
# Step 1. [4, 5, 8, 9, 11]
|
|
cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32)
|
|
total_num_sampled_tokens = cu_num_sampled_tokens[-1]
|
|
# Step 2. [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
|
|
cumsums_offsets = np.repeat(cu_num_sampled_tokens - num_sampled_tokens,
|
|
num_sampled_tokens)
|
|
# Step 3. [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
|
|
arange = self.arange_np[:total_num_sampled_tokens] - cumsums_offsets
|
|
# Step 4. [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]
|
|
logits_indices = np.repeat(
|
|
cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens)
|
|
# Step 5. [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
|
|
logits_indices += arange
|
|
|
|
# Compute the bonus logits indices.
|
|
bonus_logits_indices = cu_num_sampled_tokens - 1
|
|
|
|
# Compute the draft logits indices.
|
|
# [3, 3, 5, 5, 6]
|
|
cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32)
|
|
total_num_draft_tokens = cu_num_draft_tokens[-1]
|
|
# [0, 0, 0, 3, 3, 5]
|
|
cumsums_offsets = np.repeat(cu_num_draft_tokens - num_draft_tokens,
|
|
num_draft_tokens)
|
|
# [0, 1, 2, 0, 1, 0]
|
|
arange = self.arange_np[:total_num_draft_tokens] - cumsums_offsets
|
|
# [0, 0, 0, 5, 5, 9]
|
|
target_logits_indices = np.repeat(
|
|
cu_num_sampled_tokens - num_sampled_tokens, num_draft_tokens)
|
|
# [0, 1, 2, 5, 6, 9]
|
|
target_logits_indices += arange
|
|
|
|
# TODO: Optimize the CPU -> NPU copy.
|
|
cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to(
|
|
self.device, non_blocking=True)
|
|
logits_indices = torch.from_numpy(logits_indices).to(self.device,
|
|
non_blocking=True)
|
|
target_logits_indices = torch.from_numpy(target_logits_indices).to(
|
|
self.device, non_blocking=True)
|
|
bonus_logits_indices = torch.from_numpy(bonus_logits_indices).to(
|
|
self.device, non_blocking=True)
|
|
|
|
# Compute the draft token ids.
|
|
# draft_token_indices: [ 1, 2, 3, 105, 106, 208]
|
|
draft_token_ids = self.input_ids[logits_indices]
|
|
draft_token_ids = draft_token_ids[target_logits_indices + 1]
|
|
|
|
metadata = SpecDecodeMetadata(
|
|
draft_token_ids=draft_token_ids,
|
|
num_draft_tokens=num_draft_tokens.tolist(),
|
|
cu_num_draft_tokens=cu_num_draft_tokens,
|
|
target_logits_indices=target_logits_indices,
|
|
bonus_logits_indices=bonus_logits_indices,
|
|
logits_indices=logits_indices,
|
|
)
|
|
return metadata
|
|
|
|
def apply_grammar_bitmask(
|
|
self,
|
|
scheduler_output: "SchedulerOutput",
|
|
logits: torch.Tensor,
|
|
) -> torch.Tensor:
|
|
grammar_bitmask = scheduler_output.grammar_bitmask
|
|
|
|
# We receive the structured output bitmask from the scheduler,
|
|
# compacted to contain bitmasks only for structured output requests.
|
|
# The order of the requests in the bitmask is not guaranteed to be the
|
|
# same as the order of the requests in the gpu runner's batch. We need
|
|
# to sort the bitmask to match the order of the requests used here.
|
|
|
|
# Get the batch indices of the structured output requests.
|
|
# Keep track of the number of speculative tokens scheduled for every
|
|
# request in the batch, as the logit indices are offset by this amount.
|
|
struct_out_req_batch_indices: dict[str, int] = {}
|
|
cumulative_offset = 0
|
|
seq = sorted(self.input_batch.req_id_to_index.items(),
|
|
key=lambda x: x[1])
|
|
for req_id, batch_index in seq:
|
|
logit_index = batch_index + cumulative_offset
|
|
cumulative_offset += len(
|
|
scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
|
|
if req_id in scheduler_output.structured_output_request_ids:
|
|
struct_out_req_batch_indices[req_id] = logit_index
|
|
|
|
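# Row indices in `logits` that have a grammar bitmask applied; these are
# passed as `indices` to xgr.apply_token_bitmask_inplace below.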
out_indices = []
|
|
|
|
# Reorder the bitmask to match the order of the requests in the batch.
|
|
sorted_bitmask = np.zeros_like(grammar_bitmask,
|
|
shape=(logits.shape[0],
|
|
grammar_bitmask.shape[1]))
|
|
cumulative_index = 0
|
|
seq = sorted(scheduler_output.structured_output_request_ids.items(),
|
|
key=lambda x: x[1])
|
|
for req_id, _ in seq:
|
|
logit_index = struct_out_req_batch_indices[req_id]
|
|
num_spec_tokens = len(
|
|
scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
|
|
for i in range(1 + num_spec_tokens):
|
|
sorted_bitmask[logit_index + i] = \
|
|
grammar_bitmask[cumulative_index + i]
|
|
out_indices.append(logit_index + i)
|
|
cumulative_index += 1 + num_spec_tokens
|
|
grammar_bitmask = sorted_bitmask
|
|
|
|
# Serialization of np.ndarray is much more efficient than a tensor,
|
|
# so we receive it in that format.
|
|
grammar_bitmask = torch.from_numpy(grammar_bitmask)
|
|
|
|
# NOTE:
|
|
# 1. XGrammar bitmask applying only supports CPU and GPU.
|
|
# 2. The logits and bitmask should be on the same device.
|
|
# 3. XGrammar logits on CPU only supports float32 dtype.
|
|
logits_dtype = logits.dtype
|
|
logits = logits.to("cpu").float()
|
|
xgr.apply_token_bitmask_inplace(
|
|
logits,
|
|
grammar_bitmask,
|
|
indices=out_indices,
|
|
)
|
|
return logits.to(self.device).to(logits_dtype)
|
|
|
|
def propose_draft_token_ids(
|
|
self,
|
|
valid_sampled_token_ids: list[list[int]],
|
|
sampling_metadata: SamplingMetadata,
|
|
scheduler_output: "SchedulerOutput",
|
|
spec_decode_metadata: SpecDecodeMetadata,
|
|
positions: torch.Tensor,
|
|
num_scheduled_tokens: int,
|
|
hidden_states: torch.Tensor,
|
|
attn_metadata: dict[str, Any],
|
|
aux_hidden_states: torch.Tensor = None,
|
|
) -> Optional[list[list[int]]]:
|
|
if not self.drafter:
|
|
# Speculative decoding is not enabled.
|
|
draft_token_ids = None
|
|
else:
|
|
draft_token_ids = self.drafter.generate_token_ids(
|
|
valid_sampled_token_ids, sampling_metadata, scheduler_output,
|
|
spec_decode_metadata, positions, num_scheduled_tokens,
|
|
hidden_states, attn_metadata, aux_hidden_states)
|
|
return draft_token_ids
|
|
|
|
def _pool(
|
|
self,
|
|
hidden_states: torch.Tensor,
|
|
num_scheduled_tokens: int,
|
|
num_scheduled_tokens_np: np.ndarray,
|
|
finished_sending: Optional[set[str]] = None,
|
|
finished_recving: Optional[set[str]] = None,
|
|
kv_connector_output: Optional["KVConnectorOutput"] = None,
|
|
) -> ModelRunnerOutput:
|
|
assert self.input_batch.num_reqs ==\
|
|
len(self.input_batch.pooling_params), \
|
|
"Either all or none of the requests in" \
|
|
" a batch must be pooling request"
|
|
|
|
hidden_states = hidden_states[:num_scheduled_tokens]
|
|
pooling_metadata = self.input_batch.pooling_metadata
|
|
pooling_metadata.build_pooling_cursor(num_scheduled_tokens_np.tolist(),
|
|
device=hidden_states.device)
|
|
seq_lens_cpu = self.seq_lens_cpu[:self.input_batch.num_reqs]
|
|
|
|
model = cast(VllmModelForPooling, self.model)
|
|
raw_pooler_output = model.pooler(
|
|
hidden_states=hidden_states,
|
|
pooling_metadata=pooling_metadata,
|
|
)
|
|
raw_pooler_output = json_map_leaves(
|
|
lambda x: x.to("cpu", non_blocking=True),
|
|
raw_pooler_output,
|
|
)
|
|
torch.npu.synchronize()
|
|
|
|
pooler_output: list[Optional[torch.Tensor]] = []
|
|
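# Only emit pooler output for requests whose full prompt has been processed;
# partially prefilled requests get None for this step.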
for raw_output, seq_len, prompt_len in zip(
|
|
raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens):
|
|
output = raw_output if seq_len == prompt_len else None
|
|
pooler_output.append(output)
|
|
|
|
return ModelRunnerOutput(
|
|
req_ids=self.input_batch.req_ids,
|
|
req_id_to_index=self.input_batch.req_id_to_index,
|
|
sampled_token_ids=[],
|
|
logprobs=None,
|
|
prompt_logprobs_dict={},
|
|
pooler_output=pooler_output,
|
|
kv_connector_output=kv_connector_output,
|
|
)
|
|
|
|
def _select_moe_comm_method(self, num_tokens: int,
|
|
with_prefill: bool) -> Optional[MoECommType]:
|
|
"""1. If expert parallel is not enabled, we use all-gather since MC2 and all-to-all
|
|
are designed for expert parallelism.
|
|
2. If expert parallel is enabled, we need to consider the soc version and the
|
|
number of tokens. This is based on the observation that all-gather is more
|
|
efficient than all-to-all when running on A2.
|
|
|
|
a. For A2, we choose from MC2 and all-gather.
|
|
|
|
b. For A3, we choose from MC2 and all-to-all.
|
|
|
|
In both cases, we use MC2 when the number of tokens is smaller than
|
|
a its capacity threshold.
|
|
|
|
Args:
|
|
num_tokens (int): The number of tokens in the current batch.
|
|
|
|
Raises:
|
|
ValueError: If the soc version is unsupported.
|
|
|
|
Returns:
|
|
MoECommType: The selected MoE communication method.
|
|
"""
|
|
if not is_moe_model(self.vllm_config):
|
|
return None
|
|
|
|
soc_version = get_ascend_soc_version()
|
|
quant_type = getattr(self.vllm_config.model_config.hf_config,
|
|
'moe_quantize', None)
|
|
model_type = self.vllm_config.model_config.hf_config.model_type
|
|
|
|
if not self.parallel_config.enable_expert_parallel:
|
|
moe_comm_type = MoECommType.ALLGATHER
|
|
elif soc_version in {AscendSocVersion.A2}:
|
|
if (num_tokens <= self.mc2_tokens_capacity
|
|
and self.parallel_config.world_size_across_dp >= 16):
|
|
moe_comm_type = MoECommType.MC2
|
|
else:
|
|
# Currently, w4a8_dynamic does not support allgatherep
|
|
if quant_type == "w4a8_dynamic":
|
|
moe_comm_type = MoECommType.ALLTOALL
|
|
else:
|
|
moe_comm_type = MoECommType.ALLGATHER
|
|
|
|
elif soc_version in {AscendSocVersion.A3}:
|
|
moe_comm_type = (MoECommType.MC2
|
|
if num_tokens <= self.mc2_tokens_capacity else
|
|
MoECommType.ALLTOALL)
|
|
else:
|
|
raise ValueError(f"Unsupported soc_version: {soc_version}")
|
|
|
|
if moe_comm_type == MoECommType.ALLGATHER and with_prefill:
|
|
if enable_sp():
|
|
moe_comm_type = MoECommType.ALLGATHER
|
|
else:
|
|
moe_comm_type = MoECommType.NAIVE_MULTICAST
|
|
|
|
# PanguProMoE only supports allgather
|
|
if model_type == "PanguProMoE":
|
|
moe_comm_type = MoECommType.ALLGATHER
|
|
|
|
if is_global_first_rank():
|
|
logger.debug(f"num_tokens: {num_tokens}, "
|
|
f"moe_comm_type: {moe_comm_type}")
|
|
return moe_comm_type
|
|
|
|
@torch.inference_mode()
|
|
def execute_model(
|
|
self,
|
|
scheduler_output: "SchedulerOutput",
|
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
|
) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]:
|
|
with ProfileExecuteDuration().capture_async("prepare input"):
|
|
self._update_states(scheduler_output)
|
|
if not scheduler_output.total_num_scheduled_tokens:
|
|
if not has_kv_transfer_group():
|
|
logger.debug(
|
|
"skip this step for we receive the data from remote disaggregate prefill node"
|
|
)
|
|
# Return an empty ModelRunnerOutput if there's no work to do.
|
|
return EMPTY_MODEL_RUNNER_OUTPUT
|
|
return self.kv_connector_no_forward(scheduler_output)
|
|
|
|
if self.dynamic_eplb:
|
|
self.eplb_updator.forward_before()
|
|
|
|
(attn_metadata, positions, num_scheduled_tokens_np,
|
|
num_input_tokens, num_tokens_across_dp, maybe_padded_num_tokens,
|
|
logits_indices, spec_decode_metadata, input_ids, inputs_embeds,
|
|
intermediate_tensors,
|
|
max_query_len) = (self._prepare_inputs(scheduler_output,
|
|
intermediate_tensors))
|
|
|
|
if self.dynamic_eplb:
|
|
self.eplb_updator.take_update_info_from_eplb_process()
|
|
|
|
moe_comm_type = self._select_moe_comm_method(num_input_tokens,
|
|
self.with_prefill)
|
|
|
|
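# NOTE: a uniform decode batch is one where every request schedules exactly
# uniform_decode_query_len tokens; the dispatcher uses this, together with
# the padded token count, to pick the aclgraph runtime mode for this step.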
uniform_decode = (max_query_len == self.uniform_decode_query_len) and (
|
|
scheduler_output.total_num_scheduled_tokens
|
|
== self.input_batch.num_reqs * max_query_len)
|
|
batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
|
|
uniform_decode=uniform_decode)
|
|
aclgraph_runtime_mode, batch_descriptor = \
|
|
self.aclgraph_dispatcher.dispatch(batch_descriptor)
|
|
|
|
# Run forward pass
|
|
with ProfileExecuteDuration().capture_async("forward"):
|
|
with set_ascend_forward_context(
|
|
attn_metadata,
|
|
self.vllm_config,
|
|
num_tokens=num_input_tokens,
|
|
num_tokens_across_dp=num_tokens_across_dp,
|
|
with_prefill=self.with_prefill,
|
|
reserved_mc2_mask=self.reserved_mc2_mask,
|
|
moe_comm_type=moe_comm_type,
|
|
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
|
batch_descriptor=batch_descriptor,
|
|
num_actual_tokens=scheduler_output.
|
|
total_num_scheduled_tokens,
|
|
prefetch_stream=self.prefetch_stream,
|
|
model_instance=self.model,
|
|
weight_prefetch_method=self.weight_prefetch_method):
|
|
self.maybe_setup_kv_connector(scheduler_output)
|
|
|
|
hidden_states = self._generate_process_reqs_hidden_states(
|
|
attn_metadata, self.with_prefill, maybe_padded_num_tokens,
|
|
input_ids, positions, intermediate_tensors, inputs_embeds)
|
|
|
|
self.maybe_wait_for_kv_save()
|
|
finished_sending, finished_recving = self.get_finished_kv_transfer(
|
|
scheduler_output)
|
|
|
|
aux_hidden_states = None
|
|
if self.drafter and self.drafter.name == SpecDcodeType.EAGLE3:
|
|
hidden_states, aux_hidden_states = hidden_states
|
|
|
|
kv_connector_output = KVConnectorOutput(
|
|
finished_sending=finished_sending,
|
|
finished_recving=finished_recving)
|
|
finished_sending = None
|
|
finished_recving = None
|
|
with ProfileExecuteDuration().capture_async("post process"):
|
|
# Broadcast PP output for external_launcher (torchrun)
|
|
# to make sure we are synced across pp ranks
|
|
# TODO: Support overlapping micro-batches
|
|
# https://github.com/vllm-project/vllm/issues/18019
|
|
broadcast_pp_output = \
|
|
self.parallel_config.distributed_executor_backend \
|
|
== "external_launcher" and len(get_pp_group().ranks) > 0
|
|
if not get_pp_group().is_last_rank:
|
|
# For mid-pipeline stages, return the hidden states.
|
|
if not broadcast_pp_output:
|
|
hidden_states.kv_connector_output = kv_connector_output
|
|
return hidden_states
|
|
assert isinstance(hidden_states, IntermediateTensors)
|
|
get_pp_group().send_tensor_dict(
|
|
hidden_states.tensors, all_gather_group=get_tp_group())
|
|
logits = None
|
|
else:
|
|
if self.input_batch.pooling_params:
|
|
return self._pool(
|
|
hidden_states,
|
|
scheduler_output.total_num_scheduled_tokens,
|
|
num_scheduled_tokens_np, finished_sending,
|
|
finished_recving, kv_connector_output)
|
|
sample_hidden_states = hidden_states[logits_indices]
|
|
logits = self.model.compute_logits(sample_hidden_states)
|
|
if broadcast_pp_output:
|
|
model_output_broadcast_data = {
|
|
"logits": logits.contiguous(),
|
|
} if logits is not None else {}
|
|
model_output_broadcast_data = get_pp_group(
|
|
).broadcast_tensor_dict(model_output_broadcast_data,
|
|
src=len(get_pp_group().ranks) - 1)
|
|
assert model_output_broadcast_data is not None
|
|
logits = model_output_broadcast_data["logits"]
|
|
|
|
# Apply structured output bitmasks if present
|
|
if scheduler_output.grammar_bitmask is not None:
|
|
logits = self.apply_grammar_bitmask(scheduler_output, logits)
|
|
|
|
# Sample the next token and get logprobs if needed.
|
|
sampling_metadata = self.input_batch.sampling_metadata
|
|
if spec_decode_metadata is None:
|
|
if lmhead_tp_enable() and logits is not None:
|
|
logits = logits[:self.input_batch.num_reqs]
|
|
sampler_output = self.sampler(
|
|
logits=logits,
|
|
sampling_metadata=sampling_metadata,
|
|
)
|
|
else:
|
|
if lmhead_tp_enable() and logits is not None:
|
|
logits = logits[:len(spec_decode_metadata.logits_indices)]
|
|
# When indexing with a tensor (bonus_logits_indices), PyTorch
|
|
# creates a new tensor with separate storage from the original
|
|
# logits tensor. This means any in-place operations on bonus_logits
|
|
# won't affect the original logits tensor.
|
|
assert logits is not None
|
|
bonus_logits = logits[
|
|
spec_decode_metadata.bonus_logits_indices]
|
|
sampler_output = self.sampler(
|
|
logits=bonus_logits,
|
|
sampling_metadata=sampling_metadata,
|
|
)
|
|
bonus_token_ids = sampler_output.sampled_token_ids
|
|
|
|
# Just like `bonus_logits`, `target_logits` is a new tensor with
|
|
# separate storage from the original `logits` tensor. Therefore,
|
|
# it is safe to update `target_logits` in place.
|
|
target_logits = logits[
|
|
spec_decode_metadata.target_logits_indices]
|
|
output_token_ids = self.rejection_sampler(
|
|
spec_decode_metadata,
|
|
None, # draft_probs
|
|
target_logits,
|
|
bonus_token_ids,
|
|
sampling_metadata,
|
|
)
|
|
sampler_output.sampled_token_ids = output_token_ids
|
|
if self.need_accepted_tokens:
|
|
self._update_states_after_model_execute(output_token_ids)
|
|
|
|
# TODO(woosuk): The following loop can be slow since it iterates over
# the requests one by one. Optimize.
discard_sampled_tokens_req_indices: list[int] = []
|
|
for i, req_id in enumerate(self.input_batch.req_ids):
|
|
req_state = self.requests[req_id]
|
|
seq_len = (req_state.num_computed_tokens +
|
|
scheduler_output.num_scheduled_tokens[req_id])
|
|
if seq_len < req_state.num_tokens:
|
|
# Ignore the sampled token.
|
|
# Rewind the generator state as if the token was not sampled.
|
|
generator = self.input_batch.generators.get(i)
|
|
if generator is not None:
|
|
generator.set_offset(generator.get_offset() - 4)
|
|
discard_sampled_tokens_req_indices.append(i)
|
|
|
|
# Copy some objects so they don't get modified after returning.
|
|
# This is important when using async scheduling.
|
|
req_ids_output_copy = self.input_batch.req_ids.copy()
|
|
req_id_to_index_output_copy = \
|
|
self.input_batch.req_id_to_index.copy()
|
|
|
|
# NOTE: NPU -> CPU Sync happens here.
|
|
# Move as many CPU operations as possible before this sync point.
|
|
logprobs_tensors = sampler_output.logprobs_tensors
|
|
logprobs_lists = logprobs_tensors.tolists() \
|
|
if logprobs_tensors is not None else None
|
|
|
|
# Compute prompt logprobs if needed.
|
|
prompt_logprobs_dict = self._get_prompt_logprobs_dict(
|
|
hidden_states[:scheduler_output.total_num_scheduled_tokens],
|
|
scheduler_output,
|
|
)
|
|
|
|
num_sampled_tokens = sampler_output.sampled_token_ids.shape[0]
|
|
sampled_token_ids = sampler_output.sampled_token_ids
|
|
if not self.use_async_scheduling:
|
|
# Get the valid generated tokens.
|
|
max_gen_len = sampled_token_ids.shape[-1]
|
|
if max_gen_len == 1:
|
|
# No spec decode tokens.
|
|
valid_sampled_token_ids = sampled_token_ids.tolist()
|
|
else:
|
|
# Includes spec decode tokens.
|
|
valid_sampled_token_ids = self.rejection_sampler.parse_output(
|
|
sampled_token_ids,
|
|
self.input_batch.vocab_size,
|
|
)
|
|
# Mask out the sampled tokens that should not be sampled.
|
|
for i in discard_sampled_tokens_req_indices:
|
|
valid_sampled_token_ids[i].clear()
|
|
else:
|
|
valid_sampled_token_ids = []
|
|
invalid_req_indices = list(discard_sampled_tokens_req_indices)
|
|
invalid_req_indices_set = set(invalid_req_indices)
|
|
assert sampled_token_ids.shape[-1] == 1
|
|
|
|
# Cache the sampled tokens on the NPU and avoid CPU sync.
|
|
# These will be copied into input_ids in the next step
|
|
# when preparing inputs.
|
|
self.input_batch.prev_sampled_token_ids = \
|
|
sampled_token_ids
|
|
self.input_batch.prev_sampled_token_ids_invalid_indices = \
|
|
invalid_req_indices_set
|
|
self.input_batch.prev_req_id_to_index = {
|
|
req_id: i
|
|
for i, req_id in enumerate(self.input_batch.req_ids)
|
|
if i not in invalid_req_indices_set
|
|
}
|
|
# Cache the sampled tokens in the model runner, so that the scheduler
|
|
# doesn't need to send them back.
|
|
# NOTE(woosuk): As an exception, when using PP, the scheduler sends
|
|
# the sampled tokens back, because there's no direct communication
|
|
# between the first-stage worker and the last-stage worker.
|
|
for req_idx in range(num_sampled_tokens):
|
|
if self.use_async_scheduling:
|
|
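# With async scheduling the sampled ids are still on the NPU, so record a
# placeholder here; requests whose tokens must be discarded get None.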
sampled_ids = [-1] * 1 if \
|
|
req_idx not in invalid_req_indices_set else None
|
|
else:
|
|
sampled_ids = valid_sampled_token_ids[req_idx]
|
|
if not sampled_ids:
|
|
continue
|
|
|
|
start_idx = self.input_batch.num_tokens_no_spec[req_idx]
|
|
end_idx = start_idx + len(sampled_ids)
|
|
assert end_idx <= self.model_config.max_model_len, (
|
|
"Sampled token IDs exceed the max model length. "
|
|
f"Total number of tokens: {end_idx} > max_model_len: "
|
|
f"{self.model_config.max_model_len}")
|
|
|
|
self.input_batch.token_ids_cpu[req_idx,
|
|
start_idx:end_idx] = sampled_ids
|
|
self.input_batch.num_tokens_no_spec[req_idx] = end_idx
|
|
self.input_batch.num_tokens[req_idx] = end_idx
|
|
req_id = self.input_batch.req_ids[req_idx]
|
|
req_state = self.requests[req_id]
|
|
req_state.output_token_ids.extend(sampled_ids)
|
|
|
|
if self.speculative_config:
|
|
self._draft_token_ids = self.propose_draft_token_ids(
|
|
valid_sampled_token_ids,
|
|
sampling_metadata,
|
|
scheduler_output,
|
|
spec_decode_metadata,
|
|
positions,
|
|
scheduler_output.total_num_scheduled_tokens,
|
|
hidden_states,
|
|
attn_metadata,
|
|
aux_hidden_states,
|
|
)
|
|
|
|
if has_kv_transfer_group():
|
|
get_kv_transfer_group().clear_connector_metadata()
|
|
|
|
extra_args = ({"kv_connector_output": kv_connector_output})
|
|
|
|
model_runner_output = ModelRunnerOutput(
|
|
req_ids=req_ids_output_copy,
|
|
req_id_to_index=req_id_to_index_output_copy,
|
|
sampled_token_ids=valid_sampled_token_ids,
|
|
logprobs=logprobs_lists,
|
|
prompt_logprobs_dict=prompt_logprobs_dict,
|
|
pooler_output=[],
|
|
**extra_args,
|
|
)
|
|
|
|
durations = ProfileExecuteDuration().pop_captured_sync()
|
|
if durations:
|
|
dr_str = [
|
|
f"[{tag}]:{duration:.2f}ms"
|
|
for tag, duration in durations.items()
|
|
]
|
|
captured_name = "Decode" if self.attn_state == AscendAttentionState.DecodeOnly else "Prefill"
|
|
logger.info("Profile execute duration [%s]:%s", captured_name,
|
|
" ".join(dr_str))
|
|
if self.dynamic_eplb:
|
|
self.eplb_updator.forward_end()
|
|
if not self.use_async_scheduling:
|
|
return model_runner_output
|
|
|
|
return AsyncNPUModelRunnerOutput(
|
|
model_runner_output=model_runner_output,
|
|
sampled_token_ids=sampled_token_ids,
|
|
invalid_req_indices=invalid_req_indices,
|
|
async_output_copy_stream=self.async_output_copy_stream,
|
|
)
|
|
|
|
def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
|
|
if self._draft_token_ids is None:
|
|
return None
|
|
req_ids = self.input_batch.req_ids
|
|
if isinstance(self._draft_token_ids, torch.Tensor):
|
|
draft_token_ids = self._draft_token_ids.tolist()
|
|
else:
|
|
draft_token_ids = self._draft_token_ids
|
|
self._draft_token_ids = None
|
|
return DraftTokenIds(req_ids, draft_token_ids)
|
|
|
|
def kv_connector_no_forward(
|
|
self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
|
|
with set_ascend_forward_context(None, self.vllm_config):
|
|
self.maybe_setup_kv_connector(scheduler_output)
|
|
finished_sending, finished_recving = (
|
|
self.get_finished_kv_transfer(scheduler_output))
|
|
# When no forward pass runs because the kv cache is received from a remote node,
# one round of dummy inference is still necessary
# to prevent hanging on the collective calls.
|
|
|
|
output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
|
|
output.kv_connector_output = KVConnectorOutput(
|
|
finished_sending=finished_sending,
|
|
finished_recving=finished_recving)
|
|
return output
|
|
|
|
@staticmethod
|
|
def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"):
|
|
# Update the KVConnector with the KVConnector metadata before forward().
|
|
if has_kv_transfer_group():
|
|
kv_connector = get_kv_transfer_group()
|
|
assert isinstance(kv_connector, KVConnectorBase_V1)
|
|
assert scheduler_output.kv_connector_metadata is not None
|
|
kv_connector.bind_connector_metadata(
|
|
scheduler_output.kv_connector_metadata)
|
|
|
|
kv_connector.start_load_kv(get_forward_context())
|
|
|
|
@staticmethod
|
|
def maybe_wait_for_kv_save() -> None:
|
|
if has_kv_transfer_group():
|
|
get_kv_transfer_group().wait_for_save()
|
|
|
|
@staticmethod
|
|
def get_finished_kv_transfer(
|
|
scheduler_output: "SchedulerOutput",
|
|
) -> tuple[Optional[set[str]], Optional[set[str]]]:
|
|
if has_kv_transfer_group():
|
|
return get_kv_transfer_group().get_finished(
|
|
scheduler_output.finished_req_ids)
|
|
return None, None
|
|
|
|
def _build_dummy_attn_metadata(
|
|
self,
|
|
with_prefill: bool,
|
|
num_reqs: int,
|
|
num_tokens: int,
|
|
max_query_len: int,
|
|
aclgraph_runtime_mode: Optional[CUDAGraphMode] = None,
|
|
force_attention: bool = False,
|
|
) -> Optional[dict[str, Any]]:
|
|
attn_metadata: Optional[dict[str, Any]] = None
|
|
|
|
if force_attention or aclgraph_runtime_mode == CUDAGraphMode.FULL:
|
|
assert with_prefill is False, \
|
|
"Full decode graph only supports uniform batch now."
|
|
|
|
attn_metadata = {}
|
|
|
|
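# The dummy metadata assumes the worst-case sequence length
# (max_model_len) when building attention metadata for graph capture.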
seq_lens = self.model_config.max_model_len
|
|
self.seq_lens_np[:num_reqs] = seq_lens
|
|
self.seq_lens_np[num_reqs:] = 0
|
|
|
|
num_computed_tokens_cpu = (
|
|
self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
|
|
|
|
for kv_cache_group_id, kv_cache_group_spec in enumerate(
|
|
self.kv_cache_config.kv_cache_groups):
|
|
block_table_tensor = self.input_batch.block_table[
|
|
kv_cache_group_id].get_device_tensor()
|
|
common_attn_metadata = AscendCommonAttentionMetadata(
|
|
query_start_loc=torch.tensor(
|
|
[0] + self.actual_seq_lengths_q[:num_reqs],
|
|
device=self.device,
|
|
dtype=torch.int32),
|
|
query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs +
|
|
1],
|
|
seq_lens_cpu=self.seq_lens_cpu,
|
|
seq_lens=self.seq_lens_cpu[:num_reqs],
|
|
num_reqs=num_reqs,
|
|
num_actual_tokens=num_tokens,
|
|
actual_seq_lengths_q=self.actual_seq_lengths_q,
|
|
block_table_tensor=block_table_tensor[:num_reqs],
|
|
slot_mapping=self.slot_mapping,
|
|
num_computed_tokens_cpu=num_computed_tokens_cpu,
|
|
positions=self.positions,
|
|
attn_mask=self.attn_mask,
|
|
spec_attn_mask=self.spec_attn_mask,
|
|
attn_state=self.attn_state,
|
|
max_query_len=max_query_len,
|
|
decode_token_per_req=self.decode_token_per_req,
|
|
cos=self.cos,
|
|
sin=self.sin,
|
|
)
|
|
attn_state = AscendAttentionState.DecodeOnly
|
|
if self.speculative_config and \
|
|
self.speculative_config.method == "deepseek_mtp":
|
|
attn_state = AscendAttentionState.SpecDecoding
|
|
|
|
for attn_group in self.attn_groups[kv_cache_group_id]:
|
|
builder = attn_group.get_metadata_builder()
|
|
attn_metadata_i = builder.build_for_graph_capture(
|
|
common_attn_metadata, attn_state, self.get_model())
|
|
for layer_name in kv_cache_group_spec.layer_names:
|
|
attn_metadata[layer_name] = attn_metadata_i
|
|
|
|
return attn_metadata
|
|
|
|
def _generate_dummy_run_hidden_states(self, with_prefill,
|
|
is_torchair_compile, input_ids,
|
|
positions, attn_metadata, num_tokens,
|
|
intermediate_tensors, inputs_embeds):
|
|
hidden_states = self.model(input_ids=input_ids,
|
|
positions=positions,
|
|
intermediate_tensors=intermediate_tensors,
|
|
inputs_embeds=inputs_embeds)
|
|
forward_context = get_forward_context()
|
|
assert forward_context is not None
|
|
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and \
|
|
not forward_context.capturing:
|
|
if self.vllm_config.model_config.use_mla:
|
|
# FIXME: Try using `auto_dispatch_capture=True`
|
|
update_mla_attn_params(self.update_stream, forward_context,
|
|
positions.shape[0],
|
|
self.speculative_config)
|
|
else:
|
|
update_attn_params(self.update_stream, forward_context,
|
|
positions.shape[0])
|
|
|
|
if self.drafter and self.drafter.name == SpecDcodeType.EAGLE3:
|
|
hidden_states, _ = hidden_states
|
|
else:
|
|
hidden_states = hidden_states
|
|
return hidden_states
|
|
|
|
@torch.inference_mode()
|
|
def _dummy_run(
|
|
self,
|
|
num_tokens: int,
|
|
with_prefill: bool = False,
|
|
is_torchair_compile: bool = False,
|
|
aclgraph_runtime_mode: Optional[CUDAGraphMode] = None,
|
|
force_attention: bool = False,
|
|
uniform_decode: bool = False,
|
|
) -> torch.Tensor:
|
|
# only support eager mode and piecewise graph now
|
|
assert aclgraph_runtime_mode is None or aclgraph_runtime_mode in {
|
|
CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
|
|
}
|
|
|
|
# In multi-DP scenarios, there may be situations where all DP groups are executing dummy runs.
|
|
# If sequence parallelism is enabled, it is essential to ensure that num_tokens is divisible by tp_size.
|
|
if self.use_aclgraph and enable_sp(self.vllm_config):
|
|
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
|
|
num_tokens = math.ceil(num_tokens / tp_size) * tp_size
|
|
|
|
# Padding for DP
|
|
(num_tokens, num_tokens_across_dp, with_prefill,
|
|
_) = self._sync_metadata_across_dp(num_tokens, with_prefill, False)
|
|
|
|
moe_comm_type = self._select_moe_comm_method(num_tokens, with_prefill)
|
|
|
|
# If cudagraph_mode.decode_mode() == FULL and
# cudagraph_mode.separate_routine(), we are using different graphs
# and/or modes for mixed prefill-decode batches vs. uniform decode
# batches. A uniform decode batch means that all requests have an
# identical query length, except for a potentially shorter virtual
# request in the batch that accounts for padding.
# A uniform decode batch can either be common pure decode, where
# max_query_len == 1, or speculative decode, where
# max_query_len == 1 + num_spec_decode_tokens.

# When setting max_query_len = 1, we switch to and capture the optimized
# routine of FA2 for pure decode, i.e., Flashdecode + an optimization
# for GQA/MQA.
|
|
max_query_len = self.uniform_decode_query_len if uniform_decode else \
|
|
num_tokens
|
|
|
|
# Set num_scheduled_tokens based on num_tokens and max_num_seqs
# for the dummy run with LoRA so that the num_reqs requests
# collectively have num_tokens tokens in total.
|
|
assert num_tokens <= self.scheduler_config.max_num_batched_tokens
|
|
max_num_reqs = self.max_num_reqs
|
|
if uniform_decode:
|
|
num_reqs = cdiv(num_tokens, max_query_len)
|
|
num_scheduled_tokens_list = [max_query_len] * num_reqs
|
|
if num_tokens % max_query_len != 0:
|
|
num_scheduled_tokens_list[-1] = num_tokens % max_query_len
|
|
else:
|
|
if with_prefill:
|
|
num_reqs = num_tokens
|
|
else:
|
|
num_reqs = (num_tokens + self.decode_token_per_req -
|
|
1) // self.decode_token_per_req
|
|
num_reqs = min(num_reqs, max_num_reqs)
|
|
min_tokens_per_req = num_tokens // num_reqs
|
|
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
|
|
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
|
|
assert sum(num_scheduled_tokens_list) == num_tokens
|
|
assert len(num_scheduled_tokens_list) == num_reqs
|
|
num_scheduled_tokens = np.array(num_scheduled_tokens_list,
|
|
dtype=np.int32)
|
|
|
|
# Force a prefill-stage dummy run when this node is deemed a kv producer.
|
|
if self.is_kv_producer and not self.is_kv_consumer:
|
|
with_prefill = True
|
|
|
|
if not self.in_profile_run and self.dynamic_eplb:
|
|
self.eplb_updator.forward_before()
|
|
|
|
with self.maybe_dummy_run_with_lora(self.lora_config,
|
|
num_scheduled_tokens):
|
|
if self.is_multimodal_model:
|
|
input_ids = None
|
|
inputs_embeds = self.inputs_embeds[:num_tokens]
|
|
else:
|
|
input_ids = self.input_ids[:num_tokens]
|
|
inputs_embeds = None
|
|
|
|
if self.uses_mrope:
|
|
positions = self.mrope_positions[:, :num_tokens]
|
|
else:
|
|
positions = self.positions[:num_tokens]
|
|
|
|
if get_pp_group().is_first_rank:
|
|
intermediate_tensors = None
|
|
else:
|
|
if self.intermediate_tensors is None:
|
|
self.intermediate_tensors = (
|
|
self.model.make_empty_intermediate_tensors(
|
|
batch_size=num_tokens,
|
|
dtype=self.dtype,
|
|
device=self.device))
|
|
intermediate_tensors = IntermediateTensors({
|
|
k: v[:num_tokens]
|
|
for k, v in self.intermediate_tensors.items()
|
|
})
|
|
|
|
# filter out the valid batch descriptor
|
|
_ag_mode, batch_descriptor = \
|
|
self.aclgraph_dispatcher.dispatch(
|
|
BatchDescriptor(num_tokens=num_tokens,
|
|
uniform_decode=uniform_decode))
|
|
if aclgraph_runtime_mode is not None:
|
|
# we allow forcing NONE even when the dispatcher disagrees, to support
# warm-ups for aclgraph capture
|
|
assert aclgraph_runtime_mode == CUDAGraphMode.NONE or \
|
|
aclgraph_runtime_mode == _ag_mode, (
|
|
f"Aclgraph runtime mode mismatch at dummy_run. "
|
|
f"Expected {_ag_mode}, but got {aclgraph_runtime_mode}.")
|
|
else:
|
|
aclgraph_runtime_mode = _ag_mode
|
|
|
|
# TODO(Mengqing): Set create_mixed_batch to False since it's only used in FI warmup
|
|
# and not supported in ASCEND now. We could remove it in the future.
|
|
attn_metadata = self._build_dummy_attn_metadata(
|
|
False,
|
|
num_reqs=num_reqs,
|
|
num_tokens=num_tokens,
|
|
max_query_len=max_query_len,
|
|
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
|
force_attention=force_attention,
|
|
)
|
|
|
|
need_dummy_logits = (not self.in_profile_run
|
|
and lmhead_tp_enable())
|
|
|
|
if need_dummy_logits:
|
|
max_num_reqs_across_dp = num_tokens if not with_prefill else max_num_reqs
|
|
dummy_indices = torch.zeros(max_num_reqs_across_dp,
|
|
dtype=torch.int32)
|
|
|
|
def dummy_compute_logits(hidden_states):
|
|
return self.model.compute_logits(
|
|
hidden_states[dummy_indices])
|
|
|
|
with set_ascend_forward_context(
|
|
attn_metadata,
|
|
self.vllm_config,
|
|
num_tokens=num_tokens,
|
|
num_tokens_across_dp=num_tokens_across_dp,
|
|
with_prefill=with_prefill,
|
|
in_profile_run=self.in_profile_run,
|
|
reserved_mc2_mask=self.reserved_mc2_mask,
|
|
moe_comm_type=moe_comm_type,
|
|
num_actual_tokens=0,
|
|
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
|
batch_descriptor=batch_descriptor,
|
|
prefetch_stream=self.prefetch_stream,
|
|
model_instance=self.model,
|
|
weight_prefetch_method=self.weight_prefetch_method):
|
|
hidden_states = self._generate_dummy_run_hidden_states(
|
|
with_prefill, is_torchair_compile, input_ids, positions,
|
|
attn_metadata, num_tokens, intermediate_tensors,
|
|
inputs_embeds)
|
|
if need_dummy_logits:
|
|
dummy_compute_logits(hidden_states)
|
|
|
|
if self.drafter:
|
|
self.drafter.dummy_run(
|
|
num_tokens=num_tokens,
|
|
with_prefill=with_prefill,
|
|
skip_attn=True,
|
|
num_reqs=num_reqs,
|
|
num_tokens_across_dp=num_tokens_across_dp,
|
|
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
|
batch_descriptor=batch_descriptor)
|
|
if need_dummy_logits:
|
|
dummy_compute_logits(hidden_states)
|
|
if self.in_profile_run and self.dynamic_eplb:
|
|
self.model.clear_all_moe_loads()
|
|
if not self.in_profile_run and self.dynamic_eplb:
|
|
self.eplb_updator.take_update_info_from_eplb_process()
|
|
self.eplb_updator.forward_end()
|
|
return hidden_states
|
|
|
|
@contextmanager
|
|
def set_in_profile_run(self):
|
|
self.in_profile_run = True
|
|
try:
|
|
yield
|
|
finally:
|
|
self.in_profile_run = False
|
|
|
|
def profile_run(self) -> None:
|
|
# Trigger compilation for general shape.
|
|
with self.set_in_profile_run():
|
|
hidden_states = self._dummy_run(self.max_num_tokens,
|
|
with_prefill=True)
|
|
# MC2 will consume additional NPU memory.
|
|
# Therefore, we need to run the MC2 path once here to complete its initialization,
|
|
# allowing vLLM to correctly estimate the maximum memory required.
|
|
if self.max_num_tokens > self.mc2_tokens_capacity and \
|
|
self._select_moe_comm_method(
|
|
self.mc2_tokens_capacity,
|
|
with_prefill=True) == MoECommType.MC2:
|
|
self._dummy_run(self.mc2_tokens_capacity, with_prefill=True)
|
|
|
|
output = None
|
|
if get_pp_group().is_last_rank:
|
|
if self.is_pooling_model:
|
|
output = self._dummy_pooler_run(hidden_states)
|
|
else:
|
|
# For profile, have maximum num_reqs and that collectively have
|
|
# maximum num_tokens.
|
|
min_tokens_per_req = self.max_num_tokens // self.max_num_reqs
|
|
num_scheduled_tokens_list = [min_tokens_per_req
|
|
] * self.max_num_reqs
|
|
num_scheduled_tokens_list[
|
|
-1] += self.max_num_tokens % self.max_num_reqs
|
|
num_scheduled_tokens = np.array(num_scheduled_tokens_list,
|
|
dtype=np.int32)
|
|
logit_indices = np.cumsum(num_scheduled_tokens) - 1
|
|
# TODO: need to run a dummy sampler for the generate task
|
|
hidden_states = hidden_states[logit_indices]
|
|
output = self.model.compute_logits(hidden_states)
|
|
|
|
NPUPlatform.synchronize()
|
|
del hidden_states, output
|
|
self.encoder_cache.clear()
|
|
gc.collect()
|
|
|
|
def _dummy_pooler_run_task(
|
|
self,
|
|
hidden_states: torch.Tensor,
|
|
task: PoolingTask,
|
|
) -> PoolerOutput:
|
|
num_tokens = hidden_states.shape[0]
|
|
max_num_reqs = self.scheduler_config.max_num_seqs
|
|
num_reqs = min(num_tokens, max_num_reqs)
|
|
min_tokens_per_req = num_tokens // num_reqs
|
|
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
|
|
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
|
|
assert sum(num_scheduled_tokens_list) == num_tokens
|
|
assert len(num_scheduled_tokens_list) == num_reqs
|
|
|
|
req_num_tokens = num_tokens // num_reqs
|
|
|
|
dummy_token_ids = torch.zeros((num_reqs, req_num_tokens),
|
|
dtype=torch.int32,
|
|
device=self.device)
|
|
|
|
model = cast(VllmModelForPooling, self.get_model())
|
|
dummy_pooling_params = PoolingParams(task=task)
|
|
to_update = model.pooler.get_pooling_updates(task)
|
|
to_update.apply(dummy_pooling_params)
|
|
|
|
dummy_prompt_lens = torch.tensor(
|
|
num_scheduled_tokens_list,
|
|
device="cpu",
|
|
)
|
|
dummy_metadata = PoolingMetadata(
|
|
prompt_lens=dummy_prompt_lens,
|
|
prompt_token_ids=dummy_token_ids,
|
|
pooling_params=[dummy_pooling_params] * num_reqs,
|
|
)
|
|
|
|
dummy_metadata.build_pooling_cursor(num_scheduled_tokens_list,
|
|
device=hidden_states.device)
|
|
|
|
try:
|
|
return model.pooler(hidden_states=hidden_states,
|
|
pooling_metadata=dummy_metadata)
|
|
except RuntimeError as e:
|
|
if 'out of memory' in str(e):
|
|
raise RuntimeError(
|
|
"CUDA out of memory occurred when warming up pooler "
|
|
f"({task=}) with {num_reqs} dummy requests. Please try "
|
|
"lowering `max_num_seqs` or `gpu_memory_utilization` when "
|
|
"initializing the engine.") from e
|
|
else:
|
|
raise e
|
|
|
|
@torch.inference_mode()
|
|
def _dummy_pooler_run(
|
|
self,
|
|
hidden_states: torch.Tensor,
|
|
) -> PoolerOutput:
|
|
# Find the task that has the largest output for subsequent steps
|
|
output_size = dict[PoolingTask, float]()
|
|
for task in self.get_supported_pooling_tasks():
|
|
# Run a full batch with each task to ensure none of them OOMs
|
|
output = self._dummy_pooler_run_task(hidden_states, task)
|
|
output_size[task] = sum(o.nbytes for o in output)
|
|
del output # Allow GC
|
|
|
|
max_task = max(output_size.items(), key=lambda x: x[1])[0]
|
|
return self._dummy_pooler_run_task(hidden_states, max_task)
|
|
|
|
def eplb_warmup(self):
|
|
if self.dynamic_eplb and not self.is_eplb_warmuped:
|
|
self.is_eplb_warmuped = True
|
|
self.eplb_adaptor = VllmEplbAdaptor(model=self.model)
|
|
self.eplb_loader.set_adator(self.eplb_adaptor)
|
|
self.eplb_updator.set_adaptor(self.eplb_adaptor)
|
|
self.eplb_updator.warm_up_eplb()
|
|
|
|
def load_model(self) -> None:
|
|
logger.info("Starting to load model %s...", self.model_config.model)
|
|
|
|
with DeviceMemoryProfiler() as m: # noqa: SIM117
|
|
self.model = get_model(vllm_config=self.vllm_config)
|
|
if self.dynamic_eplb:
|
|
model_register(self.model, self.model_config)
|
|
if is_310p():
|
|
from vllm.model_executor.layers.linear import (
|
|
MergedColumnParallelLinear, QKVParallelLinear,
|
|
RowParallelLinear)
|
|
for module in self.model.modules():
|
|
if isinstance(module,
|
|
(MergedColumnParallelLinear,
|
|
QKVParallelLinear, RowParallelLinear)):
|
|
module.weight.data = self._convert_torch_format(
|
|
module.weight.data)
|
|
if self.drafter:
|
|
logger.info("Loading drafter model...")
|
|
self.drafter.load_model(self.model)
|
|
if self.drafter.name == SpecDcodeType.EAGLE3:
|
|
self.model.set_aux_hidden_state_layers(
|
|
self.model.get_eagle3_aux_hidden_state_layers())
|
|
|
|
if self.lora_config:
|
|
self.model = self.load_lora_model(self.model, self.vllm_config,
|
|
self.device)
|
|
logger.info("Loading model weights took %.4f GB",
|
|
m.consumed_memory / float(2**30))
|
|
|
|
# wrap the model with full graph wrapper if needed.
|
|
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
|
|
self.update_stream: torch.npu.Stream = torch.npu.Stream()
|
|
set_graph_params(self.compilation_config.cudagraph_capture_sizes)
|
|
self.model = ACLGraphWrapper(self.model,
|
|
self.vllm_config,
|
|
runtime_mode=CUDAGraphMode.FULL)
|
|
|
|
def _convert_torch_format(self, tensor):
|
|
if ACL_FORMAT == ACL_FORMAT_FRACTAL_NZ \
|
|
and not is_enable_nz():
|
|
return tensor
|
|
tensor = torch_npu.npu_format_cast(tensor, ACL_FORMAT)
|
|
return tensor
|
|
|
|
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
|
|
"""
|
|
Initialize KV cache based on `kv_cache_config`.
|
|
Args:
|
|
kv_cache_config: Configuration for the KV cache, including the KV
|
|
cache size of each layer
|
|
"""
|
|
kv_cache_config = deepcopy(kv_cache_config)
|
|
self.kv_cache_config = kv_cache_config
|
|
self.may_add_encoder_only_layers_to_kv_cache_config()
|
|
# NOTE(cmq): initialize_attn_backend must before using self.attn_groups
|
|
self.initialize_attn_backend(kv_cache_config)
|
|
self.use_hybrid_blocks = (len(self.attn_groups) > 1)
|
|
# NOTE: Currently, we determine whether we need `num_accepted_tokens` through `MambaSpec`.
|
|
self.need_accepted_tokens = any([
|
|
isinstance(attn_group[0].kv_cache_spec, MambaSpec)
|
|
for attn_group in self.attn_groups
|
|
])
|
|
|
|
self.may_reinitialize_input_batch(kv_cache_config)
|
|
|
|
if self.use_sparse:
|
|
kv_caches = self.initialize_kv_cache_tensors_deepseek_sfa(
|
|
kv_cache_config)
|
|
elif self.model_config.is_deepseek_mla:
|
|
kv_caches = self.initialize_kv_cache_tensors_deepseek_mla(
|
|
kv_cache_config)
|
|
else:
|
|
kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)
|
|
|
|
if has_kv_transfer_group():
|
|
get_kv_transfer_group().register_kv_caches(kv_caches)
|
|
|
|
def _align_memory(self, tensor: torch.Tensor,
|
|
alignment: int) -> torch.Tensor:
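# Return a view of `tensor` starting at the first address aligned to
# `alignment` bytes; callers slice the result back to the requested shape.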
|
|
data_ptr = tensor.data_ptr()
|
|
aligned_addr = (data_ptr + alignment - 1) // alignment * alignment
|
|
offset = (aligned_addr - data_ptr) // tensor.element_size()
|
|
return tensor[int(offset):]
|
|
|
|
def initialize_kv_cache_tensors_deepseek_sfa(
|
|
self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
|
|
kv_cache_sizes = {}
|
|
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
|
|
assert len(kv_cache_tensor.shared_by) == 1, (
|
|
"KV cache tensor shared by multiple layers is not supported in "
|
|
"NPU.")
|
|
kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size
|
|
|
|
kv_caches: Dict[str, torch.Tensor] = {}
|
|
for group in self._kv_cache_spec_attn_group_iterator():
|
|
kv_cache_spec = group.kv_cache_spec
|
|
attn_backend = group.backend
|
|
for layer_name in group.layer_names:
|
|
if layer_name in self.runner_only_attn_layers:
|
|
continue
|
|
tensor_size = kv_cache_sizes[layer_name]
|
|
num_blocks = tensor_size // kv_cache_spec.page_size_bytes
|
|
if self.vllm_config.additional_config.get(
|
|
"kv_cache_dtype", None) == 'int8':
|
|
kv_cache_shape = attn_backend.get_bsh_kv_cache_shape(
|
|
num_blocks, kv_cache_spec.block_size,
|
|
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
|
|
elif hasattr(
|
|
attn_backend, "get_supported_block_size"
|
|
) and not self.model_config.is_deepseek_mla and not self.use_sparse:
|
|
block_size = attn_backend.get_supported_block_size()[0]
|
|
block_size_chunk = kv_cache_spec.block_size // block_size
|
|
kv_cache_shape = attn_backend.get_kv_cache_shape(
|
|
num_blocks * block_size_chunk, block_size,
|
|
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
|
|
else:
|
|
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
|
|
num_blocks, kv_cache_spec.block_size,
|
|
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
|
|
dtype = kv_cache_spec.dtype
|
|
|
|
alignment = 2 * 1024 * 1024
|
|
num_blocks, block_size, num_kv_heads, head_size = kv_cache_shape
|
|
rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
|
|
nope_dim = head_size - rope_dim
|
|
nope_cache_shape = (num_blocks, block_size, num_kv_heads,
|
|
nope_dim)
|
|
rope_cache_shape = (num_blocks, block_size, num_kv_heads,
|
|
rope_dim)
|
|
#### k cache
|
|
# TODO(zzzzwwjj): wait transformers add these params
|
|
k_cache_shape = (num_blocks, block_size, 1, 128)
|
|
if self.vllm_config.kv_transfer_config is None:
|
|
# For no disaggregate pd scenario, allocate kv cache in normal way
|
|
rope_cache = torch.zeros(rope_cache_shape,
|
|
dtype=dtype,
|
|
device=self.device)
|
|
nope_cache = torch.zeros(nope_cache_shape,
|
|
dtype=dtype,
|
|
device=self.device)
|
|
rope_cache = self._convert_torch_format(rope_cache)
|
|
nope_cache = self._convert_torch_format(nope_cache)
|
|
|
|
#### k cache
|
|
k_cache = torch.zeros(k_cache_shape,
|
|
dtype=dtype,
|
|
device=self.device)
|
|
k_cache = self._convert_torch_format(k_cache)
|
|
else:
|
|
|
|
# In order to transfer the kv cache through the register_memory api from llmdatadist, the memory
# address should be aligned to 2M. In most cases, torch_npu allocates 2M-aligned memory, but
# we found some exceptions during testing, so we manually align the memory here. This part
# of the code may consume 2M * 2 * elem_size extra memory per layer.
                    nope_allocate_shape = num_blocks * block_size * num_kv_heads * nope_dim
                    nope_allocate_shape_alignment = nope_allocate_shape + alignment
                    rope_allocate_shape = num_blocks * block_size * num_kv_heads * rope_dim
                    rope_allocate_shape_alignment = rope_allocate_shape + alignment

                    nope_cache = torch.zeros(nope_allocate_shape_alignment,
                                             dtype=dtype,
                                             device=self.device)
                    rope_cache = torch.zeros(rope_allocate_shape_alignment,
                                             dtype=dtype,
                                             device=self.device)
                    #### k cache
                    # TODO(zzzzwwjj): wait transformers add these params
                    k_allocate_shape = num_blocks * block_size * 1 * 128
                    k_allocate_shape_alignment = k_allocate_shape + alignment
                    k_cache = torch.zeros(k_allocate_shape_alignment,
                                          dtype=dtype,
                                          device=self.device)

                    nope_cache = self._align_memory(
                        nope_cache,
                        alignment)[:nope_allocate_shape].view(nope_cache_shape)
                    rope_cache = self._align_memory(
                        rope_cache,
                        alignment)[:rope_allocate_shape].view(rope_cache_shape)
                    k_cache = self._align_memory(
                        k_cache,
                        alignment)[:k_allocate_shape].view(k_cache_shape)

                kv_caches[layer_name] = (nope_cache, rope_cache, k_cache)
        bind_kv_cache(kv_caches,
                      self.compilation_config.static_forward_context,
                      self.kv_caches)

        return kv_caches

    def initialize_kv_cache_tensors_deepseek_mla(
            self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
        kv_cache_sizes = {}
        for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
            assert len(kv_cache_tensor.shared_by) == 1, (
                "KV cache tensor shared by multiple layers is not supported in "
                "NPU.")
            kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size

        kv_caches: Dict[str, torch.Tensor] = {}
        for group in self._kv_cache_spec_attn_group_iterator():
            kv_cache_spec = group.kv_cache_spec
            attn_backend = group.backend
            for layer_name in group.layer_names:
                if layer_name in self.runner_only_attn_layers:
                    continue
                tensor_size = kv_cache_sizes[layer_name]
                num_blocks = tensor_size // kv_cache_spec.page_size_bytes
                if self.vllm_config.additional_config.get(
                        "kv_cache_dtype", None) == 'int8':
                    kv_cache_shape = attn_backend.get_bsh_kv_cache_shape(
                        num_blocks, kv_cache_spec.block_size,
                        kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
                elif hasattr(attn_backend, "get_supported_block_size"
                             ) and not self.model_config.is_deepseek_mla:
                    block_size = attn_backend.get_supported_block_size()[0]
                    block_size_chunk = kv_cache_spec.block_size // block_size
                    kv_cache_shape = attn_backend.get_kv_cache_shape(
                        num_blocks * block_size_chunk, block_size,
                        kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
                else:
                    kv_cache_shape = self.attn_backend.get_kv_cache_shape(
                        num_blocks, kv_cache_spec.block_size,
                        kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
                dtype = kv_cache_spec.dtype

                alignment = 2 * 1024 * 1024
                num_blocks, block_size, num_kv_heads, head_size = kv_cache_shape
                rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
                nope_dim = head_size - rope_dim
                nope_cache_shape = (num_blocks, block_size, num_kv_heads,
                                    nope_dim)
                rope_cache_shape = (num_blocks, block_size, num_kv_heads,
                                    rope_dim)
                if self.vllm_config.kv_transfer_config is None:
                    # For the non-disaggregated prefill/decode scenario,
                    # allocate the kv cache in the normal way.
                    rope_cache = torch.zeros(rope_cache_shape,
                                             dtype=dtype,
                                             device=self.device)
                    nope_cache = torch.zeros(nope_cache_shape,
                                             dtype=dtype,
                                             device=self.device)
                    rope_cache = self._convert_torch_format(rope_cache)
                    nope_cache = self._convert_torch_format(nope_cache)
                else:
                    # In order to transfer the kv cache through the
                    # register_memory API from llmdatadist, the memory address
                    # should be aligned to 2M. In most cases torch_npu can
                    # allocate 2M-aligned memory, but we found some exceptions
                    # during testing, so we manually align the memory here;
                    # this may consume an extra 2M * 2 * elem_size bytes per
                    # layer.
                    nope_allocate_shape = num_blocks * block_size * num_kv_heads * nope_dim
                    nope_allocate_shape_alignment = nope_allocate_shape + alignment
                    rope_allocate_shape = num_blocks * block_size * num_kv_heads * rope_dim
                    rope_allocate_shape_alignment = rope_allocate_shape + alignment

                    nope_cache = torch.zeros(nope_allocate_shape_alignment,
                                             dtype=dtype,
                                             device=self.device)
                    rope_cache = torch.zeros(rope_allocate_shape_alignment,
                                             dtype=dtype,
                                             device=self.device)
                    nope_cache = self._align_memory(
                        nope_cache,
                        alignment)[:nope_allocate_shape].view(nope_cache_shape)
                    rope_cache = self._align_memory(
                        rope_cache,
                        alignment)[:rope_allocate_shape].view(rope_cache_shape)
                kv_caches[layer_name] = (nope_cache, rope_cache)

        bind_kv_cache(kv_caches,
                      self.compilation_config.static_forward_context,
                      self.kv_caches)

        return kv_caches

    def initialize_kv_cache_tensors(
            self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
        """
        Initialize the memory buffer for KV cache.

        Args:
            kv_cache_config: The KV cache config
        Returns:
            Dict[str, torch.Tensor]: A map from layer names to their
            corresponding memory buffers for the KV cache.
        """
        # init kv cache tensors
        kv_cache_raw_tensors: dict[str, Union[torch.Tensor,
                                              Optional[torch.Tensor]]] = {}
        # llmdatadist needs the address of the cache tensor to be aligned
        # with 2M
        alignment = 2 * 1024 * 1024
        for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
            # TODO: REFACTOR ME to sharing hybrid cache
            for idx in range(len(kv_cache_tensor.shared_by)):
                layer_name = kv_cache_tensor.shared_by[idx]
                if "linear_attn" in layer_name:
                    # for mamba linear attention
                    for layer_name_inner in kv_cache_tensor.shared_by:
                        if ("attn" in layer_name_inner and "linear_attn" not in layer_name_inner) or \
                                layer_name_inner in kv_cache_raw_tensors.keys():
                            continue
                        if self.vllm_config.kv_transfer_config is None:
                            tensor = torch.zeros(kv_cache_tensor.size,
                                                 dtype=torch.int8,
                                                 device=self.device)
                        else:
                            cache_size_aligned = kv_cache_tensor.size + alignment
                            tensor = torch.zeros(cache_size_aligned,
                                                 dtype=torch.int8,
                                                 device=self.device)
                            tensor = self._align_memory(
                                tensor, alignment)[:kv_cache_tensor.size]
                        kv_cache_raw_tensors[layer_name_inner] = tensor
                elif "attn" in layer_name:
                    # for other attention types, e.g., self_attn, sliding
                    # window attn
                    if self.vllm_config.kv_transfer_config is None:
                        k_tensor = torch.zeros(kv_cache_tensor.size // 2,
                                               dtype=torch.int8,
                                               device=self.device)
                        v_tensor = torch.zeros(kv_cache_tensor.size // 2,
                                               dtype=torch.int8,
                                               device=self.device)
                    else:
                        cache_size = kv_cache_tensor.size // 2
                        cache_size_aligned = kv_cache_tensor.size // 2 + alignment
                        k_tensor = torch.zeros(cache_size_aligned,
                                               dtype=torch.int8,
                                               device=self.device)
                        v_tensor = torch.zeros(cache_size_aligned,
                                               dtype=torch.int8,
                                               device=self.device)
                        k_tensor = self._align_memory(k_tensor,
                                                      alignment)[:cache_size]
                        v_tensor = self._align_memory(v_tensor,
                                                      alignment)[:cache_size]
                    kv_cache_raw_tensors[layer_name] = (k_tensor, v_tensor)

        layer_names = set()
        for group in kv_cache_config.kv_cache_groups:
            for layer_name in group.layer_names:
                if layer_name in self.runner_only_attn_layers:
                    continue
                layer_names.add(layer_name)
        assert layer_names == set(kv_cache_raw_tensors.keys(
        )), "Some layers are not correctly initialized"
        kv_caches: Dict[str, torch.Tensor] = {}
        for group in self._kv_cache_spec_attn_group_iterator():
            kv_cache_spec = group.kv_cache_spec
            attn_backend = group.backend
            for layer_name in group.layer_names:
                if layer_name in self.runner_only_attn_layers:
                    continue

                # TODO: remove this after the OOM issue is located and fixed;
                # otherwise some models may encounter OOM issues
                if isinstance(kv_cache_spec, FullAttentionSpec):
                    raw_k_tensor, raw_v_tensor = kv_cache_raw_tensors[  # type: ignore
                        layer_name]
                    assert raw_k_tensor is not None
                    assert raw_v_tensor is not None
                    assert (raw_k_tensor.numel() + raw_v_tensor.numel()
                            ) % kv_cache_spec.page_size_bytes == 0
                    num_blocks = (raw_k_tensor.numel() + raw_v_tensor.numel()
                                  ) // kv_cache_spec.page_size_bytes

                    # `num_blocks` is the number of blocks the model runner can use.
                    # `kv_cache_config.num_blocks` is the number of blocks that
                    # KVCacheManager may allocate.
                    # Since different GPUs may have different number of layers and
                    # different memory capacities, `num_blocks` can be different on
                    # different GPUs, and `kv_cache_config.num_blocks` is set to
                    # the min of all `num_blocks`. Verify it here.
                    assert num_blocks >= kv_cache_config.num_blocks

                    if self.vllm_config.additional_config.get(
                            "kv_cache_dtype", None) == 'int8':
                        kv_cache_shape = attn_backend.get_bsh_kv_cache_shape(
                            num_blocks, kv_cache_spec.block_size,
                            kv_cache_spec.num_kv_heads,
                            kv_cache_spec.head_size)
                    elif hasattr(attn_backend, "get_supported_block_size"
                                 ) and self.use_hybrid_blocks:
                        block_size = attn_backend.get_supported_block_size()[0]

                        block_size_chunk = kv_cache_spec.block_size // block_size
                        kv_cache_shape = attn_backend.get_kv_cache_shape(
                            num_blocks * block_size_chunk, block_size,
                            kv_cache_spec.num_kv_heads,
                            kv_cache_spec.head_size)
                    else:
                        kv_cache_shape = self.attn_backend.get_kv_cache_shape(
                            num_blocks, kv_cache_spec.block_size,
                            kv_cache_spec.num_kv_heads,
                            kv_cache_spec.head_size)
                    dtype = kv_cache_spec.dtype
                    k_cache = raw_k_tensor.view(dtype).view(kv_cache_shape[1:])
                    k_cache = self._convert_torch_format(k_cache)
                    v_cache = raw_v_tensor.view(dtype).view(kv_cache_shape[1:])
                    v_cache = self._convert_torch_format(v_cache)
                    kv_caches[layer_name] = (k_cache, v_cache)
                elif isinstance(kv_cache_spec, MambaSpec):
                    raw_tensor = kv_cache_raw_tensors[layer_name]
                    assert raw_tensor is not None
                    assert raw_tensor.numel(
                    ) % kv_cache_spec.page_size_bytes == 0
                    num_blocks = raw_tensor.numel(
                    ) // kv_cache_spec.page_size_bytes

                    # `num_blocks` is the number of blocks the model runner can use.
                    # `kv_cache_config.num_blocks` is the number of blocks that
                    # KVCacheManager may allocate.
                    # Since different GPUs may have different number of layers and
                    # different memory capacities, `num_blocks` can be different on
                    # different GPUs, and `kv_cache_config.num_blocks` is set to
                    # the min of all `num_blocks`. Verify it here.
                    assert num_blocks >= kv_cache_config.num_blocks

                    state_tensors = []
                    storage_offset_bytes = 0
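                    # Carve each Mamba state tensor out of the shared raw
                    # buffer: every page holds all states of one block, so
                    # each state tensor uses the full page size (in elements)
                    # as its per-block stride and starts at a growing byte
                    # offset within the page.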
                    for (shape, dtype) in zip(kv_cache_spec.shapes,
                                              kv_cache_spec.dtypes):
                        dtype_size = get_dtype_size(dtype)
                        num_element_per_page = (
                            kv_cache_spec.page_size_bytes // dtype_size)
                        target_shape = (num_blocks, *shape)
                        stride = torch.empty(target_shape).stride()
                        target_stride = (num_element_per_page, *stride[1:])
                        assert storage_offset_bytes % dtype_size == 0
                        tensor = torch.as_strided(
                            raw_tensor.view(dtype),
                            size=target_shape,
                            stride=target_stride,
                            storage_offset=storage_offset_bytes // dtype_size,
                        )
                        state_tensors.append(tensor)
                        storage_offset_bytes += stride[0] * dtype_size
                    kv_caches[layer_name] = state_tensors
                else:
                    raise ValueError("Unknown KV cache spec type.")

        bind_kv_cache(kv_caches,
                      self.compilation_config.static_forward_context,
                      self.kv_caches)

        return kv_caches

    def may_reinitialize_input_batch(self,
                                     kv_cache_config: KVCacheConfig) -> None:
        """
        Re-initialize the input batch if the block sizes are different from
        `[self.cache_config.block_size]`. This usually happens when there
        are multiple KV cache groups.

        Args:
            kv_cache_config: The KV cache configuration.
        """
        block_sizes = [
            kv_cache_group.kv_cache_spec.block_size
            for kv_cache_group in kv_cache_config.kv_cache_groups
            if not isinstance(kv_cache_group.kv_cache_spec,
                              EncoderOnlyAttentionSpec)
        ]

        # Generate kernel_block_sizes entries that match each block_size.
        # For attention backends that support virtual block splitting,
        # use the supported block sizes from the backend.
        # For other backends (like Mamba), use [0] (no splitting).
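        # Illustration: an attention group typically contributes its backend's
        # supported kernel block sizes (falling back to the configured cache
        # block size), while a Mamba group contributes [0] so no slot mapping
        # is computed for it.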
        kernel_block_sizes = []
        for kv_cache_group_id, kv_cache_group in enumerate(
                kv_cache_config.kv_cache_groups):

            if isinstance(kv_cache_group.kv_cache_spec,
                          EncoderOnlyAttentionSpec):
                continue
            elif isinstance(kv_cache_group.kv_cache_spec, AttentionSpec):
                # This is an attention backend that supports virtual
                # block splitting. Get the supported block sizes from
                # the backend.
                try:
                    attn_groups = self.attn_groups[kv_cache_group_id]
                except IndexError:
                    attn_groups = None
                if attn_groups and self.use_hybrid_blocks:
                    # Use the backend's supported block size list
                    backend = attn_groups[0].backend
                    supported_sizes = backend.get_supported_block_size()
                    # If no specific sizes are supported, use the cache
                    # config block_size
                    kernel_block_size_list = (supported_sizes
                                              if supported_sizes else
                                              [self.cache_config.block_size])
                else:
                    # Fall back to the cache config block_size if no backend
                    # is found
                    kernel_block_size_list = [self.cache_config.block_size]
                kernel_block_sizes.append(kernel_block_size_list)
            else:
                # This is likely Mamba or another non-attention cache;
                # no splitting.
                # NOTE: set kernel_block_sizes to [0] to disable slot-mapping
                # computation for the mamba block. In this case,
                # BlockTable.block_size will never equal
                # kernel_block_sizes[0].
                kernel_block_sizes.append([0])

        if block_sizes != [
                self.cache_config.block_size
        ] or kernel_block_sizes != [[self.cache_config.block_size]]:
            assert self.cache_config.cpu_offload_gb == 0, (
                "Cannot re-initialize the input batch when CPU weight "
                "offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 "  # noqa: E501
                "for more details.")
            self.input_batch = InputBatch(
                max_num_reqs=self.max_num_reqs,
                max_model_len=self.model_config.max_model_len,
                max_num_batched_tokens=self.max_num_tokens,
                device=self.device,
                pin_memory=self.pin_memory,
                vocab_size=self.model_config.get_vocab_size(),
                block_sizes=block_sizes,
                is_spec_decode=bool(self.vllm_config.speculative_config),
                logitsprocs=self.input_batch.logitsprocs,
                is_pooling_model=self.is_pooling_model,
                num_speculative_tokens=(
                    self.vllm_config.speculative_config.num_speculative_tokens
                    if self.vllm_config.speculative_config else 0),
                kernel_block_sizes=kernel_block_sizes,
            )

    def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
        """
        Add encoder-only layers to the KV cache config.
        """
        block_size = self.vllm_config.cache_config.block_size
        encoder_only_attn_specs: dict[AttentionSpec,
                                      list[str]] = defaultdict(list)
        attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
        for layer_name, attn_module in attn_layers.items():
            if attn_module.attn_type == AttentionType.ENCODER_ONLY:
                attn_spec: AttentionSpec = EncoderOnlyAttentionSpec(
                    block_size=block_size,
                    num_kv_heads=attn_module.num_kv_heads,
                    head_size=attn_module.head_size,
                    dtype=self.kv_cache_dtype)
                encoder_only_attn_specs[attn_spec].append(layer_name)
                self.runner_only_attn_layers.add(layer_name)
        if len(encoder_only_attn_specs) > 0:
            assert len(
                encoder_only_attn_specs
            ) == 1, "Only support one encoder-only attention spec now"
            spec, layer_names = encoder_only_attn_specs.popitem()
            self.kv_cache_config.kv_cache_groups.append(
                KVCacheGroupSpec(layer_names=layer_names, kv_cache_spec=spec))

    def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Initialize the attention backends and attention metadata builders.
        """
        assert len(self.attn_groups) == 0, \
            "Attention backends are already initialized"

        class AttentionGroupKey(NamedTuple):
            attn_backend: type[AttentionBackend]
            kv_cache_spec: KVCacheSpec

        def get_attn_backends_for_group(
            kv_cache_group_spec: KVCacheGroupSpec,
        ) -> dict[AttentionGroupKey, list[str]]:
            layers = get_layers_from_vllm_config(
                self.vllm_config, AttentionLayerBase,
                kv_cache_group_spec.layer_names)
            attn_backends = {}
            attn_backend_layers = defaultdict(list)
            # Dedupe based on the full class name; this is a bit safer than
            # using the class itself as the key, because when we create
            # dynamic attention backend subclasses (e.g. ChunkedLocalAttention)
            # there will be different objects per layer unless they are cached
            # correctly.
            for layer_name in kv_cache_group_spec.layer_names:
                attn_backend = layers[layer_name].get_attn_backend()
                full_cls_name = attn_backend.full_cls_name()
                layer_kv_cache_spec = kv_cache_group_spec.kv_cache_spec
                if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs):
                    layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[
                        layer_name]
                key = (full_cls_name, layer_kv_cache_spec)
                attn_backends[key] = AttentionGroupKey(attn_backend,
                                                       layer_kv_cache_spec)
                attn_backend_layers[key].append(layer_name)
            return {
                attn_backends[k]: v
                for k, v in attn_backend_layers.items()
            }
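
        # Layers inside one KV cache group are further partitioned by
        # (attention backend class, per-layer spec); each resulting
        # AttentionGroup built below gets its own metadata builder.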
        def create_attn_groups(
            attn_backends_map: dict[AttentionBackend, list[str]],
        ) -> list[AttentionGroup]:
            attn_groups: list[AttentionGroup] = []
            for (attn_backend,
                 kv_cache_spec), layer_names in attn_backends_map.items():
                attn_metadata_builders = []
                attn_metadata_builders.append(attn_backend.get_builder_cls()(
                    kv_cache_spec,
                    layer_names,
                    self.vllm_config,
                    self.device,
                ))
                attn_group = AttentionGroup(attn_backend,
                                            attn_metadata_builders,
                                            layer_names, kv_cache_spec)
                attn_groups.append(attn_group)
            return attn_groups

        for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
            attn_backends = get_attn_backends_for_group(  # type: ignore
                kv_cache_group_spec)
            self.attn_groups.append(create_attn_groups(attn_backends))

        # Calculate reorder batch threshold (if needed)
        self.calculate_reorder_batch_threshold()

    def _attn_group_iterator(self) -> Iterator[AttentionGroup]:
        return itertools.chain.from_iterable(self.attn_groups)

    def _kv_cache_spec_attn_group_iterator(self) -> Iterator[AttentionGroup]:
        if not self.kv_cache_config.kv_cache_groups:
            return
        for attn_groups in self.attn_groups:
            yield from attn_groups

    def calculate_reorder_batch_threshold(self) -> None:
        """
        Check that, if any backends reorder batches, the reordering
        is compatible (e.g., the decode threshold is the same).
        """
        for group in self._attn_group_iterator():
            attn_metadata_builder_i = group.get_metadata_builder()
            if hasattr(attn_metadata_builder_i, "reorder_batch_threshold"):
                # Check that, if any backends reorder batches, the reordering
                # is compatible (e.g., the decode threshold is the same).
                reorder_batch_threshold_i = (
                    attn_metadata_builder_i.reorder_batch_threshold)
                if reorder_batch_threshold_i is not None:
                    if self.reorder_batch_threshold is not None:
                        if reorder_batch_threshold_i != \
                                self.reorder_batch_threshold:
                            raise ValueError(
                                f"Attention backend reorders decodes with "
                                f"threshold {reorder_batch_threshold_i} but other "
                                f"backend uses threshold "
                                f"{self.reorder_batch_threshold}")
                    else:
                        self.reorder_batch_threshold = reorder_batch_threshold_i

    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """
        Generates the KVCacheSpec by parsing the kv cache format from each
        Attention module in the static forward context.
        Returns:
            KVCacheSpec: A dictionary mapping layer names to their KV cache
            format. Layers that do not need KV cache are not included.
        """

        block_size = self.vllm_config.cache_config.block_size
        use_mla = self.vllm_config.model_config.use_mla
        use_sparse = self.use_sparse
        kv_cache_spec: dict[str, KVCacheSpec] = {}
        attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
        for layer_name, attn_module in attn_layers.items():
            if (kv_tgt_layer :=
                    attn_module.kv_sharing_target_layer_name) is not None:
                # The layer doesn't need its own KV cache and will use that of
                # the target layer. We skip creating a KVCacheSpec for it, so
                # that KV cache management logic will act as if this layer
                # does not exist, and doesn't allocate KV cache for the layer.
                # This enables the memory saving of cross-layer kv sharing,
                # allowing a given amount of memory to accommodate longer
                # context lengths or more requests to be processed
                # simultaneously.
                self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
                continue
            if isinstance(attn_module, AscendMultiHeadLatentAttention):
                continue

            # TODO: Support other attention modules, e.g., cross-attention
            # TODO(lucas): move the attention specs into the model layers like
            # the attention backends
            if attn_module.attn_type == AttentionType.DECODER:
                if use_mla and not use_sparse:
                    kv_cache_spec[layer_name] = MLAAttentionSpec(
                        block_size=block_size,
                        num_kv_heads=attn_module.num_kv_heads,
                        head_size=attn_module.head_size,
                        dtype=self.kv_cache_dtype,
                        cache_dtype_str=self.cache_config.cache_dtype)
                else:
                    # TODO(cmq): This is a hacky way to fix the deepseek
                    # kvcache when using DSA. Fixing the spec in vLLM is the
                    # final solution.
                    kv_cache_spec[layer_name] = FullAttentionSpec(
                        block_size=block_size,
                        num_kv_heads=attn_module.num_kv_heads,
                        head_size=attn_module.head_size,
                        dtype=self.kv_cache_dtype)
            elif attn_module.attn_type in (AttentionType.ENCODER,
                                           AttentionType.ENCODER_ONLY):
                # encoder-only attention does not need KV cache.
                continue
            elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
                raise NotImplementedError
            else:
                raise ValueError(
                    f"Unknown attention type: {attn_module.attn_type}")

        mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaBase)
        if len(mamba_layers) > 0:
            if (self.vllm_config.speculative_config is not None
                    and self.vllm_config.model_config.hf_config.model_type
                    not in ["qwen3_next"]):
                raise NotImplementedError(
                    "Mamba with speculative decoding is not supported yet.")
            if self.vllm_config.cache_config.enable_prefix_caching:
                raise NotImplementedError(
                    "Prefix caching is not supported for Mamba yet.")
            max_model_len = self.vllm_config.model_config.max_model_len

            page_size_padded = (
                self.vllm_config.cache_config.mamba_page_size_padded)

            # Set block_size to max_model_len, so that the mamba model will
            # always have only one block in the KV cache.
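            # Consequently each request owns exactly one state page per Mamba
            # layer; page_size_padded, when set, pads that page to a fixed
            # size (e.g. so it can line up with the attention page size in
            # hybrid models).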
            for layer_name, mamba_module in mamba_layers.items():
                kv_cache_spec[layer_name] = MambaSpec(
                    shapes=mamba_module.get_state_shape(),
                    dtypes=mamba_module.get_state_dtype(),
                    block_size=max_model_len,
                    page_size_padded=page_size_padded,
                    mamba_type=mamba_module.mamba_type,
                    num_speculative_blocks=(
                        self.speculative_config.num_speculative_tokens
                        if self.speculative_config else 0),
                )

        return kv_cache_spec

    def initialize_aclgraph_capture(self) -> None:
        min_ag_support = AttentionCGSupport.ALWAYS
        min_ag_builder_name = None

        for attn_group in self._attn_group_iterator():
            builder = attn_group.get_metadata_builder()
            if builder.aclgraph_support.value < min_ag_support.value:
                min_ag_support = builder.aclgraph_support
                min_ag_builder_name = builder.__class__.__name__

        # This is an imitation of
        # compilation_config.splitting_ops_contain_attention()
        splitting_ops_contain_attention = (
            self.compilation_config.splitting_ops is not None
            and all(op in self.compilation_config.splitting_ops for op in [
                "vllm.unified_ascend_attention_with_output",
                "vllm.mla_forward",
            ]))
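
        # If full-graph capture is requested but some attention builder cannot
        # always be captured, the mode is downgraded below: to
        # FULL_AND_PIECEWISE when the attention ops are in splitting_ops,
        # otherwise to FULL_DECODE_ONLY.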
        # Flexibly resolve the aclgraph mode
        aclgraph_mode = self.compilation_config.cudagraph_mode
        # check that graph capture for mixed batches is supported
        if aclgraph_mode.mixed_mode() == CUDAGraphMode.FULL \
                and min_ag_support != AttentionCGSupport.ALWAYS:
            msg = (f"ACLGraphMode.{aclgraph_mode.name} is not supported "
                   f"with {min_ag_builder_name} backend (support: "
                   f"{min_ag_support})")
            if min_ag_support == AttentionCGSupport.NEVER:
                # if no full graphs are supported at all, just raise it.
                msg += "; please try cudagraph_mode=PIECEWISE, and "\
                    "make sure compilation level is piecewise"
                raise ValueError(msg)

            # attempt to resolve the full graph related mode
            if splitting_ops_contain_attention:
                msg += "; setting cudagraph_mode=FULL_AND_PIECEWISE"
                aclgraph_mode = self.compilation_config.cudagraph_mode = (
                    CUDAGraphMode.FULL_AND_PIECEWISE)
            else:
                msg += "; setting cudagraph_mode=FULL_DECODE_ONLY"
                aclgraph_mode = self.compilation_config.cudagraph_mode = (
                    CUDAGraphMode.FULL_DECODE_ONLY)
            logger.warning(msg)

        # double check that we can support full graphs if they are requested,
        # even after automatic downgrades
        if aclgraph_mode.has_full_cudagraphs() \
                and min_ag_support == AttentionCGSupport.NEVER:
            raise ValueError(f"CUDAGraphMode.{aclgraph_mode.name} is not "
                             f"supported with {min_ag_builder_name} backend ("
                             f"support:{min_ag_support}) "
                             "; please try cudagraph_mode=PIECEWISE, "
                             "and make sure compilation level is piecewise")

        self.aclgraph_dispatcher.initialize_cudagraph_keys(
            self.compilation_config.cudagraph_mode,
            self.uniform_decode_query_len)

    def _capture_aclgraphs(self, compilation_cases: list[int],
                           aclgraph_runtime_mode: CUDAGraphMode,
                           uniform_decode: bool):
        assert aclgraph_runtime_mode != CUDAGraphMode.NONE and \
            aclgraph_runtime_mode in [CUDAGraphMode.FULL,
                                      CUDAGraphMode.PIECEWISE]

        # Only rank 0 should print the progress bar during capture
        if is_global_first_rank():
            logger.info(
                "Starting to capture ACL graphs for cases: %s, "
                "mode: %s, uniform_decode: %s", compilation_cases,
                aclgraph_runtime_mode.name, uniform_decode)
            compilation_cases = tqdm(
                compilation_cases,
                disable=not self.load_config.use_tqdm_on_load,
                desc="Capturing ACL graphs ({}, {})".format(
                    "decode" if uniform_decode else "mixed prefill-decode",
                    aclgraph_runtime_mode.name))
        # We skip EPLB here since we don't want to record dummy metrics
        for num_tokens in compilation_cases:
            for _ in range(self.compilation_config.cudagraph_num_of_warmups):
                # Use CUDAGraphRuntimeStyle.NONE (default) for warmup.
                # But be careful: warming up with `NONE` is orthogonal to
                # whether we want to warm up attention or not. This is
                # different from the case where `FULL` implies capturing
                # attention while `PIECEWISE` implies no attention.
                force_attention = (aclgraph_runtime_mode == CUDAGraphMode.FULL)
                self._dummy_run(num_tokens,
                                aclgraph_runtime_mode=CUDAGraphMode.NONE,
                                force_attention=force_attention,
                                uniform_decode=uniform_decode)
            self._dummy_run(num_tokens,
                            aclgraph_runtime_mode=aclgraph_runtime_mode,
                            force_attention=force_attention,
                            uniform_decode=uniform_decode)

    def _capture_model(self):
        if not self.use_aclgraph:
            logger.warning(
                "Skipping ACL graph capture. To turn on ACL graph capture, "
                "ensure `aclgraph_mode` was not manually set to `NONE`")
            return
        else:
            self.initialize_aclgraph_capture()

        set_cudagraph_capturing_enabled(True)
        # Trigger ACL graph capture for specific shapes.
        # Capture the large shapes first so that the smaller shapes
        # can reuse the memory pool allocated for the large shapes.
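        # aclgraph_batch_sizes is assumed to be sorted in ascending order
        # here, so reversed() below iterates from the largest capture size
        # down to the smallest.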
        with graph_capture(device=self.device):
            aclgraph_mode = self.compilation_config.cudagraph_mode
            if aclgraph_mode.mixed_mode() != CUDAGraphMode.NONE:
                aclgraph_runtime_mode = aclgraph_mode.mixed_mode()

                compilation_cases = list(reversed(self.aclgraph_batch_sizes))

                try:
                    self._capture_aclgraphs(
                        compilation_cases,
                        aclgraph_runtime_mode=aclgraph_runtime_mode,
                        uniform_decode=False)
                except Exception as e:
                    error_msg = str(e)
                    error_code = '0x7020023'
                    pattern = r'retCode=([^,\s\.]+)'
                    match = re.search(pattern, error_msg)
                    if match:
                        retCode = match.group(1)
                    # Determine whether the error is caused by a stream
                    # capture failure.
                    if match and retCode == error_code:
                        logger.error(
                            f"ACLgraph sizes capture failed: {type(e).__name__}:\n"
                            "ACLgraph has insufficient available streams to capture the configured number of sizes. "
                            "Please verify both the availability of adequate streams and the appropriateness of the configured size count.\n\n"
                            "Recommended solutions:\n"
                            "1. Manually configure the compilation_config parameter "
                            "with a reduced set of sizes: '{\"cudagraph_capture_sizes\":[size1, size2, size3, ...]}'.\n"
                            "2. Utilize ACLgraph's full graph mode as an alternative to the piece-wise approach.\n\n"
                            f"{str(e)}")
                    raise

            if aclgraph_mode.decode_mode() == CUDAGraphMode.FULL and \
                    aclgraph_mode.separate_routine():
                max_num_tokens = self.scheduler_config.max_num_seqs * \
                    self.uniform_decode_query_len
                decode_cudagraph_batch_sizes = [
                    x for x in self.aclgraph_batch_sizes if x <= max_num_tokens
                    and x >= self.uniform_decode_query_len
                ]
                compilation_cases_decode = list(
                    reversed(decode_cudagraph_batch_sizes))
                self._capture_aclgraphs(
                    compilation_cases=compilation_cases_decode,
                    aclgraph_runtime_mode=CUDAGraphMode.FULL,
                    uniform_decode=True)

        # Disable aclgraph capturing globally, so any unexpected aclgraph
        # capturing will be detected and raise an error after here.
        # Note: we don't put it into the graph_capture context manager because
        # we may do lazy capturing in the future, which still allows capturing
        # after here.
        set_cudagraph_capturing_enabled(False)

    def capture_model(self) -> None:

        compilation_counter.num_gpu_runner_capture_triggers += 1

        start_time = time.perf_counter()
        start_free_npu_memory = torch.npu.mem_get_info()[0]

        self._capture_model()

        end_time = time.perf_counter()
        end_free_npu_memory = torch.npu.mem_get_info()[0]
        elapsed_time = end_time - start_time
        npu_graph_size = start_free_npu_memory - end_free_npu_memory
        # This usually takes 5~20 seconds.
        logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
                    elapsed_time, npu_graph_size / (1 << 30))

    def _get_prompt_logprobs_dict(
        self,
        hidden_states: torch.Tensor,
        scheduler_output: "SchedulerOutput",
    ) -> dict[str, Optional[LogprobsTensors]]:
        num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs
        if not num_prompt_logprobs_dict:
            return {}

        in_progress_dict = self.input_batch.in_progress_prompt_logprobs_cpu
        prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {}

        # Since prompt logprobs are a rare feature, prioritize a simple,
        # maintainable loop over optimal performance.
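        # Each request is handled chunk by chunk: only the prompt positions
        # scheduled in this step get logits computed, and the resulting
        # logprobs are copied asynchronously into a per-request CPU tensor
        # until its prefill finishes.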
        completed_prefill_reqs = []
        for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items():

            num_tokens = scheduler_output.num_scheduled_tokens[req_id]

            # Get metadata for this request.
            request = self.requests[req_id]
            num_prompt_tokens = len(request.prompt_token_ids)
            prompt_token_ids = torch.tensor(request.prompt_token_ids).to(
                self.device, non_blocking=True)

            # Set up the target LogprobsTensors object.
            logprobs_tensors = in_progress_dict.get(req_id)
            if not logprobs_tensors:
                # Create empty logprobs CPU tensors for the entire prompt.
                # If chunked, we'll copy them in slice by slice.
                logprobs_tensors = LogprobsTensors.empty_cpu(
                    num_prompt_tokens - 1, num_prompt_logprobs + 1)
                in_progress_dict[req_id] = logprobs_tensors

            # Determine the number of logits to retrieve.
            start_idx = request.num_computed_tokens
            start_tok = start_idx + 1
            num_remaining_tokens = num_prompt_tokens - start_tok
            if num_tokens <= num_remaining_tokens:
                # This is a chunk; more tokens remain.
                # In the == case, there are no more prompt logprobs to produce,
                # but we want to defer returning them to the next step, where
                # we have new generated tokens to return.
                num_logits = num_tokens
            else:
                # This is the last chunk of prompt tokens to return.
                num_logits = num_remaining_tokens
                completed_prefill_reqs.append(req_id)
                prompt_logprobs_dict[req_id] = logprobs_tensors

            if num_logits <= 0:
                # This can happen for the final chunk if we prefilled exactly
                # (num_prompt_tokens - 1) tokens for this request in the prior
                # step. There are no more prompt logprobs to produce.
                continue

            # Get the logits corresponding to this req's prompt tokens.
            # If this is a partial request (i.e. chunked prefill), then a
            # prompt logprob is generated for each index.
            req_idx = self.input_batch.req_id_to_index[req_id]
            offset = self.query_start_loc_np[req_idx].item()
            prompt_hidden_states = hidden_states[offset:offset + num_logits]
            logits = self.model.compute_logits(prompt_hidden_states)

            # Get the "target" tokens for each index. For the prompt at
            # index i, the token at prompt index i+1 is the "sampled" token we
            # want to gather the logprob for.
            tgt_token_ids = prompt_token_ids[start_tok:start_tok + num_logits]

            # Compute prompt logprobs.
            logprobs = self.sampler.compute_logprobs(logits)
            token_ids, logprobs, ranks = self.sampler.gather_logprobs(
                logprobs, num_prompt_logprobs, tgt_token_ids)

            # Transfer NPU->CPU asynchronously.
            chunk_slice = slice(start_idx, start_idx + num_logits)
            logprobs_tensors.logprob_token_ids[chunk_slice].copy_(
                token_ids, non_blocking=True)
            logprobs_tensors.logprobs[chunk_slice].copy_(logprobs,
                                                         non_blocking=True)
            logprobs_tensors.selected_token_ranks[chunk_slice].copy_(
                ranks, non_blocking=True)

        # Remove requests that have completed prefill from the batch
        # num_prompt_logprobs_dict.
        for req_id in completed_prefill_reqs:
            del num_prompt_logprobs_dict[req_id]
            del in_progress_dict[req_id]

        # Must synchronize the non-blocking NPU->CPU transfers.
        if prompt_logprobs_dict:
            torch.npu.synchronize()

        return prompt_logprobs_dict

    def get_supported_pooling_tasks(self):
        model = self.get_model()
        if not is_pooling_model(model):
            return []

        return list(model.pooler.get_supported_tasks())

    def _build_drafter_prepare_inputs_torchair_param(self):
        return False