RFC: https://github.com/vllm-project/vllm-ascend/issues/4629
Reason:
The metadata data class contains too many variables. We will inherit
the community's metadata class and, at the same time, remove the
variables that are no longer needed.
Todo:
1. remove attn_state partly.
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
3384 lines
166 KiB
Python
3384 lines
166 KiB
Python
#
|
||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||
# Copyright 2023 The vLLM team.
|
||
#
|
||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
# you may not use this file except in compliance with the License.
|
||
# You may obtain a copy of the License at
|
||
#
|
||
# http://www.apache.org/licenses/LICENSE-2.0
|
||
#
|
||
# Unless required by applicable law or agreed to in writing, software
|
||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
# See the License for the specific language governing permissions and
|
||
# limitations under the License.
|
||
# This file is a part of the vllm-ascend project.
|
||
# Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
|
||
#
|
||
|
||
import math
|
||
import time
|
||
from collections import defaultdict
|
||
from contextlib import contextmanager, nullcontext
|
||
from copy import copy, deepcopy
|
||
from dataclasses import dataclass
|
||
from multiprocessing import Manager
|
||
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Union
|
||
|
||
import numpy as np
|
||
import regex as re
|
||
import torch
|
||
import torch.distributed as dist
|
||
import torch.nn as nn
|
||
from tqdm import tqdm # type: ignore
|
||
from vllm.attention.backends.abstract import AttentionBackend, AttentionType
|
||
from vllm.attention.layer import Attention, MLAAttention
|
||
from vllm.attention.selector import get_attn_backend
|
||
from vllm.compilation.counter import compilation_counter
|
||
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
|
||
from vllm.config import (CompilationMode, CUDAGraphMode, VllmConfig,
|
||
get_layers_from_vllm_config)
|
||
from vllm.distributed import (get_tensor_model_parallel_world_size,
|
||
tensor_model_parallel_all_gather)
|
||
from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
|
||
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
||
has_kv_transfer_group)
|
||
from vllm.distributed.parallel_state import (get_dcp_group, get_dp_group,
|
||
get_pcp_group, get_pp_group,
|
||
get_tp_group,
|
||
is_global_first_rank)
|
||
from vllm.forward_context import get_forward_context
|
||
from vllm.logger import logger
|
||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||
from vllm.model_executor.layers.mamba.abstract import MambaBase
|
||
from vllm.model_executor.model_loader import get_model
|
||
from vllm.sequence import IntermediateTensors
|
||
from vllm.utils.import_utils import LazyLoader
|
||
from vllm.utils.math_utils import cdiv
|
||
from vllm.utils.mem_utils import DeviceMemoryProfiler
|
||
from vllm.utils.torch_utils import get_dtype_size
|
||
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
|
||
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
|
||
CommonAttentionMetadata)
|
||
from vllm.v1.kv_cache_interface import (AttentionSpec,
|
||
EncoderOnlyAttentionSpec,
|
||
FullAttentionSpec, KVCacheConfig,
|
||
KVCacheGroupSpec, KVCacheSpec,
|
||
MambaSpec, MLAAttentionSpec,
|
||
UniformTypeKVCacheSpecs)
|
||
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
|
||
LogprobsLists, LogprobsTensors, ModelRunnerOutput,
|
||
SamplerOutput,
|
||
make_empty_encoder_model_runner_output)
|
||
from vllm.v1.sample.metadata import SamplingMetadata
|
||
from vllm.v1.sample.rejection_sampler import RejectionSampler
|
||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
|
||
from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer
|
||
from vllm.v1.structured_output.utils import apply_grammar_bitmask
|
||
from vllm.v1.worker.gpu_model_runner import (AsyncGPUModelRunnerOutput,
|
||
GPUModelRunner)
|
||
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput
|
||
from vllm.v1.worker.utils import AttentionGroup
|
||
|
||
import vllm_ascend.envs as envs_ascend
|
||
from vllm_ascend.ascend_config import get_ascend_config
|
||
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
|
||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||
AscendPrefillContextParallelMetadata)
|
||
# yapf conflicts with isort for this block
|
||
# yapf: disable
|
||
from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
|
||
set_graph_params,
|
||
set_mtp_graph_params,
|
||
update_attn_dcp_pcp_params,
|
||
update_attn_params,
|
||
update_mla_attn_dcp_pcp_params,
|
||
update_mla_attn_params)
|
||
# yapf: enable
|
||
from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor
|
||
from vllm_ascend.eplb.core.eplb_device_transfer_loader import \
|
||
D2DExpertWeightLoader
|
||
from vllm_ascend.eplb.core.eplb_utils import EPLBParamUtils
|
||
from vllm_ascend.eplb.core.eplb_worker import EplbProcess
|
||
from vllm_ascend.eplb.eplb_updator import EplbUpdator
|
||
from vllm_ascend.eplb.utils import model_register
|
||
from vllm_ascend.ops.rotary_embedding import set_cos_and_sin, update_cos_sin
|
||
from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
|
||
from vllm_ascend.patch.worker.patch_module import patch_torch_npu_argsort
|
||
from vllm_ascend.sample.logits_processor import build_logitsprocs
|
||
from vllm_ascend.sample.sampler import AscendSampler
|
||
from vllm_ascend.spec_decode import get_spec_decode_method
|
||
from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
|
||
from vllm_ascend.spec_decode.interface import SpecDcodeType
|
||
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
|
||
from vllm_ascend.utils import (AscendDeviceType, ProfileExecuteDuration,
|
||
enable_sp, get_ascend_device_type, is_moe_model,
|
||
lmhead_tp_enable, maybe_trans_nz,
|
||
vllm_version_is)
|
||
from vllm_ascend.worker.npu_input_batch import NPUInputBatch
|
||
|
||
from vllm_ascend.ascend_forward_context import ( # isort: skip
|
||
MoECommType, get_mc2_tokens_capacity, select_moe_comm_method,
|
||
set_ascend_forward_context, set_mc2_mask, set_mc2_tokens_capacity)
|
||
|
||
if TYPE_CHECKING:
|
||
import xgrammar as xgr # type: ignore[import-untyped]
|
||
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
||
else:
|
||
xgr = LazyLoader("xgr", globals(), "xgrammar")
|
||
|
||
import torch_npu
|
||
|
||
# Module-level NPU configuration: these statements run once, at import time.
# if true, allow tensor initialization and casting with internal format (e.g., NZ)
torch.npu.config.allow_internal_format = True

# On 310P devices, switch the torch_npu compile mode to non-JIT.
if get_ascend_device_type() == AscendDeviceType._310P:
    torch_npu.npu.set_compile_mode(jit_compile=False)
|
||
|
||
|
||
@dataclass
class GraphCaptureContext:
    """Data handed to the body of a `graph_capture` block.

    Currently it only carries the stream that the graph capture runs on.
    """
    # NPU stream used for the capture.
    stream: torch.npu.Stream
|
||
|
||
|
||
@contextmanager
def graph_capture(device: torch.device):
    """Context manager that should wrap the code capturing an NPU graph.

    A dedicated NPU stream is created on ``device`` and made the current
    stream for the duration of the ``with`` body, so that the kernels being
    captured are explicitly separated from kernels that may be launched in
    the background on the previously current stream. Before entering, the
    capture stream waits on the previously current stream so that pending
    initialization work has completed.

    Yields:
        GraphCaptureContext: currently only carries the capture stream.
    """
    capture_stream = torch.npu.Stream(device=device)
    capture_context = GraphCaptureContext(capture_stream)

    # Make sure everything already queued on the active stream is done
    # before capturing on the separate stream.
    active_stream = torch.npu.current_stream()
    if active_stream != capture_stream:
        capture_stream.wait_stream(active_stream)

    # The second context manager is a no-op placeholder (nullcontext).
    with torch.npu.stream(capture_stream), nullcontext():
        yield capture_context
|
||
|
||
|
||
class ExecuteModelState(NamedTuple):
    """Ephemeral cached state transferred between execute_model() and
    sample_tokens(), after execute_model() returns None."""

    # Forward reference: SchedulerOutput is imported only under TYPE_CHECKING.
    scheduler_output: "SchedulerOutput"
    logits: torch.Tensor
    # None when no speculative-decoding metadata is attached.
    spec_decode_metadata: SpecDecodeMetadata | None
    hidden_states: torch.Tensor
    sample_hidden_states: torch.Tensor
    # Optional list of auxiliary hidden-state tensors.
    aux_hidden_states: list[torch.Tensor] | None
    # Attention metadata mapping (presumably keyed by layer name - TODO confirm).
    attn_metadata: dict[str, Any]
    positions: torch.Tensor
|
||
|
||
|
||
class NPUModelRunner(GPUModelRunner):
|
||
|
||
def __init__(self, vllm_config: VllmConfig, device: torch.device):
    """Build the NPU model runner.

    Runs the ``GPUModelRunner`` constructor inside ``_torch_cuda_wrapper()``
    and then sets up the Ascend-specific state: context-parallel sizes and
    ranks, sampler, attention backend and mask builder, speculative-decoding
    drafter, KV-transfer role flags, PCP buffers, optional dynamic EPLB
    machinery and the ``NPUInputBatch``.

    Args:
        vllm_config: Engine configuration; forwarded to the parent
            constructor and read for parallel/cache/kv-transfer settings.
        device: Device the runner works on; forwarded to the parent
            constructor and used for the optional prefetch stream.

    Raises:
        RuntimeError: If dumping is enabled while ``enforce_eager`` is off.
    """
    # `_torch_cuda_wrapper` is defined outside this view.
    with _torch_cuda_wrapper():
        super().__init__(vllm_config, device)
    self.max_num_reqs = self.scheduler_config.max_num_seqs
    self.dp_size = vllm_config.parallel_config.data_parallel_size
    self.dp_rank = vllm_config.parallel_config.data_parallel_rank
    # Query the DCP / PCP groups; if anything raises, fall back to
    # single-rank defaults for all four values.
    try:
        self.dcp_size = get_dcp_group().world_size
        self.dcp_rank = get_dcp_group().rank_in_group
        self.pcp_size = get_pcp_group().world_size
        self.pcp_rank = get_pcp_group(
        ).rank_in_group if self.pcp_size > 1 else 0
    except Exception:
        self.dcp_size = 1
        self.dcp_rank = 0
        self.pcp_size = 1
        self.pcp_rank = 0
    # With PCP enabled, max_model_len is enlarged in place on model_config
    # by 2 * pcp_size * max_num_reqs.
    if self.pcp_size > 1:
        self.model_config.max_model_len += 2 * self.pcp_size * self.max_num_reqs
    # A dedicated stream is only created when MLP prefetch is enabled.
    if envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP:
        self.prefetch_stream = torch.npu.Stream(device=device)
    else:
        self.prefetch_stream = None
    self.sampler = AscendSampler()
    self.attn_mask = None
    self.attn_state = None

    # Ascend-specific configurations
    self.ascend_config = get_ascend_config()
    self.weight_prefetch_method = WeightPrefetchMethod(
        self.ascend_config.weight_prefetch_config)
    # Dump / PrecisionDebugger configuration now comes from AscendConfig
    dump_cfg = self.ascend_config.dump_config
    self.dump_enable = dump_cfg.enable_dump
    self.debugger = None
    if self.dump_enable:
        if self.model_config.enforce_eager:
            # Imported lazily so msprobe is only needed when dumping is on.
            from msprobe.pytorch import PrecisionDebugger
            self.debugger = PrecisionDebugger(dump_cfg.config_path)
        else:
            raise RuntimeError(
                "Dumping/debugging only works in eager mode.")
    # use_hybrid_blocks: if hybrid blocks is used.
    self.use_hybrid_blocks: bool = False
    self.need_accepted_tokens: bool = False

    self.is_multimodal_model = self.model_config.is_multimodal_model
    self.block_size = vllm_config.cache_config.block_size
    # Set up Attention
    # Sparse attention is selected when the HF config has an `index_topk`
    # attribute.
    self.use_sparse = hasattr(self.vllm_config.model_config.hf_config,
                              "index_topk")
    # The two branches differ only in the extra `use_mm_prefix` keyword
    # passed for vLLM versions other than 0.12.0.
    if vllm_version_is('0.12.0'):
        self.attn_backend = get_attn_backend(
            0,
            self.dtype,
            None,
            self.block_size,
            use_mla=self.model_config.use_mla,
            use_sparse=self.use_sparse)
    else:
        self.attn_backend = get_attn_backend(
            0,
            self.dtype,
            None,
            self.block_size,
            use_mla=self.model_config.use_mla,
            use_sparse=self.use_sparse,
            use_mm_prefix=self.model_config is not None
            and self.model_config.is_mm_prefix_lm)
    self.attn_mask_builder = AttentionMaskBuilder(self.device)

    # Must run after attn_mask_builder exists: the drafter setup reads it.
    self._set_up_drafter()

    # kv role
    self.is_kv_producer = False
    self.is_kv_consumer = False
    if vllm_config.kv_transfer_config is not None:
        self.is_kv_producer = vllm_config.kv_transfer_config.is_kv_producer
        self.is_kv_consumer = vllm_config.kv_transfer_config.is_kv_consumer

    set_cos_and_sin(vllm_config, self.max_num_reqs,
                    self.uniform_decode_query_len, self.dtype, self.device)
    set_mc2_tokens_capacity(vllm_config, self.max_num_reqs,
                            self.uniform_decode_query_len)
    set_mc2_mask(vllm_config, self.device)
    # PCP buffers are sized max_num_tokens + 2 * pcp_size * max_num_reqs.
    self.pcp_allgather_restore_idx = torch.zeros(
        self.max_num_tokens + 2 * self.pcp_size * self.max_num_reqs,
        dtype=torch.int32,
        device=self.device)
    # One (initially empty) index list per PCP rank.
    self.cp_kv_recover_idx_for_chunk: List[List[int]] = [
        [] for _ in range(self.pcp_size)
    ]

    self.num_pcp_pads = torch.zeros(self.max_num_reqs, dtype=torch.int32)
    self.pcp_padded_slot_mapping = torch.zeros(
        self.max_num_tokens + 2 * self.pcp_size * self.max_num_reqs,
        dtype=torch.int32,
        device=self.device)
    self.num_actual_tokens_pcp_padded = 0
    # Extra "full" (pre-PCP-split) buffers, only needed when speculative
    # decoding and PCP are both active.
    if self.speculative_config and self.pcp_size > 1:
        self.input_ids_pcp_full = self._make_buffer(self.max_num_tokens,
                                                    dtype=torch.int32)
        self.query_start_loc_pcp_full = self._make_buffer(
            self.max_num_reqs + 1, dtype=torch.int32)
        self.positions_pcp_full = torch.zeros(self.max_num_tokens,
                                              dtype=torch.int64,
                                              device="cpu",
                                              pin_memory=True)
        # NOTE(review): _set_up_drafter() already set decode_token_per_req
        # to 1 + num_speculative_tokens; this adds num_speculative_tokens a
        # second time - confirm that is intended.
        self.decode_token_per_req += self.speculative_config.num_speculative_tokens
        self.positions_pcp_full_np = self.positions_pcp_full.numpy()
    self.decode_threshold = 1 + (
        self.speculative_config.num_speculative_tokens
        if self.speculative_config else 0)

    self.use_aclgraph = self._use_aclgraph()

    # Dynamic EPLB is considered on when either the flag or an expert-map
    # record path is configured.
    self.dynamic_eplb = self.ascend_config.dynamic_eplb or self.ascend_config.expert_map_record_path
    if self.dynamic_eplb:
        EPLBParamUtils.check_dynamic_eplb(self.ascend_config.dynamic_eplb)
        EPLBParamUtils.check_expert_map_record_path(
            self.ascend_config.expert_map_record_path)
        self.is_eplb_warmuped = False
        self.policy_type = self.ascend_config.eplb_policy_type
        self.eplb_loader = D2DExpertWeightLoader()
        # multiprocessing Manager dict shared with the EPLB process.
        self.manager = Manager()
        self.shared_dict = self.manager.dict({
            "expert_map": None,
            "moe_load": None,
            "expert_maps": None
        })
        self.eplb_process = EplbProcess(shared_dict=self.shared_dict,
                                        policy_type=self.policy_type,
                                        enable_d2d=True)
        self.process = self.eplb_process._launch_process()
        ascend_config = get_ascend_config()
        self.eplb_updator = EplbUpdator(ascend_config, self.eplb_loader,
                                        self.eplb_process, self.process)
    # Input Batch
    # NOTE(Chen): Ideally, we should initialize the input batch inside
    # `initialize_kv_cache` based on the kv cache config. However, as in
    # https://github.com/vllm-project/vllm/pull/18298, due to some unknown
    # reasons, we have to initialize the input batch before `load_model`,
    # quantization + weight offloading will fail otherwise. As a temporary
    # solution, we initialize the input batch here, and re-initialize it
    # in `initialize_kv_cache` if the block_sizes here is different from
    # the block_sizes in the kv cache config.
    self.input_batch = NPUInputBatch(
        max_num_reqs=self.max_num_reqs,
        max_model_len=self.model_config.max_model_len,
        max_num_batched_tokens=self.max_num_tokens,
        device=self.device,
        pin_memory=self.pin_memory,
        vocab_size=self.model_config.get_vocab_size(),
        block_sizes=[self.block_size],
        kernel_block_sizes=[[self.cache_config.block_size]],
        is_spec_decode=bool(self.vllm_config.speculative_config),
        logitsprocs=build_logitsprocs(
            self.vllm_config, self.device, self.pin_memory,
            self.is_pooling_model,
            self.vllm_config.model_config.logits_processors),
        is_pooling_model=self.is_pooling_model,
        num_speculative_tokens=(
            self.vllm_config.speculative_config.num_speculative_tokens
            if self.vllm_config.speculative_config else 0),
        cp_kv_cache_interleave_size=self.parallel_config.
        cp_kv_cache_interleave_size,
    )
    self.num_draft_tokens = self._make_buffer(self.max_num_reqs,
                                              dtype=torch.int32)
    # here we use int32
    self.sampled_token_ids_pinned_cpu = torch.empty(
        (self.max_num_reqs, 1),
        dtype=torch.int32,
        device="cpu",
        pin_memory=self.pin_memory,
    )
    # For clean code: the next three attributes are actually defined in
    # gpu_model_runner.
    self.execute_model_state: ExecuteModelState | None = None
    # None in the first PP rank. The rest are set after load_model.
    self.intermediate_tensors: IntermediateTensors | None = None
    self.reorder_batch_threshold: int | None = None
|
||
|
||
def _init_device_properties(self) -> None:
|
||
self.num_sms = None
|
||
|
||
def _sync_device(self) -> None:
    """Synchronize by calling ``torch.npu.synchronize()``."""
    torch.npu.synchronize()
|
||
|
||
def _set_up_drafter(self):
    """Initialise the speculative-decoding state of the runner.

    Always resets ``spec_attn_mask``, ``drafter``, ``actual_seq_lengths_q``
    and ``decode_token_per_req`` to their non-speculative defaults, then
    overrides them when ``self.speculative_config`` is set. Finally it
    allocates the discard-request buffer and resets its counter.

    Reads ``self.attn_mask_builder``, ``self.sampler`` and
    ``self.max_num_tokens``, so those must already be set.
    """
    # Set up speculative decoding.
    self.spec_attn_mask = None
    self.drafter: Optional[Union[NgramProposer, EagleProposer, MtpProposer,
                                 SuffixDecodingProposer]] = None
    self.actual_seq_lengths_q: list[int] = []
    # One token per request per decode step unless speculation is on.
    self.decode_token_per_req = 1
    if self.speculative_config:
        spec_token_num = self.speculative_config.num_speculative_tokens
        assert spec_token_num > 0
        # One regular token plus the speculative tokens per request.
        self.decode_token_per_req = 1 + spec_token_num
        self.spec_attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask(
        )
        # The drafter and rejection sampler are only created on the last
        # pipeline-parallel rank.
        if get_pp_group().is_last_rank:
            self.drafter = self._get_drafter()
            self.rejection_sampler = RejectionSampler(self.sampler)
        # NOTE(review): nesting reconstructed from a view without
        # indentation; this is assumed to sit under `if speculative_config`
        # but outside the last-rank check - confirm.
        # Multiples of decode_token_per_req up to max_num_tokens.
        self.actual_seq_lengths_q = list(
            range(self.decode_token_per_req, self.max_num_tokens + 1,
                  self.decode_token_per_req))
    self.discard_request_indices = self._make_buffer(self.max_num_reqs,
                                                     dtype=torch.int64)
    self.num_discarded_requests = 0
|
||
|
||
def _get_drafter(self):
    """Return the drafter that ``get_spec_decode_method`` builds for the
    configured speculative method; requires ``self.speculative_config``."""
    spec_method = self.speculative_config.method
    return get_spec_decode_method(spec_method, self.vllm_config, self.device,
                                  self)
|
||
|
||
def _use_aclgraph(self) -> bool:
    """Tell whether ACL graph mode is in effect.

    True only when the cudagraph mode is not ``NONE``, the compilation mode
    is ``VLLM_COMPILE`` and eager execution is not enforced.
    """
    compilation = self.compilation_config
    graph_mode_requested = compilation.cudagraph_mode != CUDAGraphMode.NONE
    return (graph_mode_requested
            and compilation.mode == CompilationMode.VLLM_COMPILE
            and not self.model_config.enforce_eager)
|
||
|
||
def _skip_all_reduce_acorss_dp_group(self) -> bool:
    """Decide whether the cross-DP all_reduce of batch metadata can be skipped.

    On D nodes the all_reduce (and the padding of tokens to the maximum
    across DP ranks) can be avoided. For MoE models this requires that
    num_tokens never exceeds mc2_tokens_capacity, i.e. the MoE comm method
    of every rank is MC2. For dense models skipping is not worthwhile:
    dp_size is always small in dense deployments and the collective can be
    overlapped by async scheduling.
    """
    if not is_moe_model(self.vllm_config):
        return False

    # Largest token count a step could present to the MoE layers.
    compilation = self.compilation_config
    token_upper_bound = (compilation.max_cudagraph_capture_size
                         if compilation.cudagraph_capture_sizes else
                         self.max_num_reqs * self.uniform_decode_query_len)

    # Skipping is only valid when every rank uses MC2 and recomputation can
    # never happen on D nodes, hence the recompute_scheduler_enable check.
    return (self.is_kv_consumer
            and self.ascend_config.recompute_scheduler_enable
            and select_moe_comm_method(token_upper_bound, self.vllm_config)
            in {MoECommType.MC2, MoECommType.FUSED_MC2})
|
||
|
||
def _sync_metadata_across_dp(
|
||
self, num_tokens: int,
|
||
with_prefill: bool) -> tuple[int, Optional[torch.Tensor], bool]:
|
||
# TODO: In vLLM, the only thing that needs to be synced is num_tokens, but in
|
||
# our case, we still need to sync the other two flags as well. So we need to
|
||
# include them in the all_reduce operation, and more over, we CANNOT skip it
|
||
# even if we are running in eager mode, which harms performance.
|
||
# FIXME: Restore the `or self.vllm_config.model_config.enforce_eager` here
|
||
# immediately once the other two flags are no longer needed.
|
||
if self.dp_size == 1:
|
||
return num_tokens, None, with_prefill
|
||
|
||
if self._skip_all_reduce_acorss_dp_group():
|
||
num_tokens_after_padding = torch.tensor([num_tokens] *
|
||
self.dp_size,
|
||
device="cpu",
|
||
dtype=torch.int32)
|
||
return num_tokens, num_tokens_after_padding, with_prefill
|
||
|
||
# Sync num_tokens, with_prefill across dp ranks
|
||
num_tokens_tensor = torch.tensor([
|
||
num_tokens if i == self.dp_rank else 0 for i in range(self.dp_size)
|
||
],
|
||
dtype=torch.int32,
|
||
device="cpu")
|
||
|
||
flags_tensor = torch.tensor([int(with_prefill)],
|
||
dtype=torch.int32,
|
||
device="cpu")
|
||
|
||
packed_tensor = torch.cat([num_tokens_tensor, flags_tensor])
|
||
# use cpu_group to avoid cpu synchronization issue.
|
||
# it can be overlapped with main moell execution on npu.
|
||
dist.all_reduce(packed_tensor, group=get_dp_group().cpu_group)
|
||
|
||
# Unpack the results
|
||
num_tokens_across_dp = packed_tensor[:-1]
|
||
synced_flags = packed_tensor[-1:]
|
||
max_tokens_across_dp = torch.max(num_tokens_across_dp).item()
|
||
global_with_prefill = bool(synced_flags[0])
|
||
|
||
# Create a tensor for num_tokens_after_padding
|
||
num_tokens_after_padding = torch.tensor([max_tokens_across_dp] *
|
||
self.dp_size,
|
||
device="cpu",
|
||
dtype=torch.int32)
|
||
|
||
return max_tokens_across_dp, num_tokens_after_padding, global_with_prefill
|
||
|
||
def get_model(self) -> nn.Module:
    """Return the raw model, unwrapping it when it is held inside an
    ``ACLGraphWrapper``."""
    model = self.model
    return model.unwrap() if isinstance(model, ACLGraphWrapper) else model
|
||
|
||
def _make_attention_mask(self, attn_state) -> torch.Tensor:
|
||
# pcp situation.
|
||
if self.attn_mask_builder is None:
|
||
raise ValueError("Attn mask builder is None")
|
||
# Pooling situation.
|
||
if self.model_config.runner_type == "pooling":
|
||
return self.attn_mask_builder.get_attn_mask(2048, torch.bool)
|
||
|
||
if self.vllm_config.model_config.use_mla:
|
||
if self.pcp_size > 1:
|
||
return self.attn_mask_builder.get_pcp_mla_mask(self.dtype)
|
||
# mla prefill
|
||
if attn_state != AscendAttentionState.DecodeOnly:
|
||
return self.attn_mask_builder.get_mla_mask(self.dtype)
|
||
return self.attn_mask_builder.get_splitfuse_attn_mask()
|
||
|
||
def generate_kv_idx(self, scheduler_output):
    """Rebuild ``self.cp_kv_recover_idx_for_chunk`` for the current batch.

    Does nothing unless ``self.pcp_size > 1``. For every request that is
    still in prefill (computed tokens < prompt tokens) the scheduled token
    count is padded up to a multiple of ``2 * pcp_size`` and split into
    ``2 * pcp_size`` chunks; rank ``r`` receives chunk ``r`` from the front
    and chunk ``r`` counted from the back. The per-rank index lists are then
    concatenated in rank order, moved to ``self.device`` and replaced by
    their argsort (as int32), which is stored on the instance.

    Args:
        scheduler_output: Object whose ``num_scheduled_tokens`` maps each
            request id in ``self.input_batch.req_ids`` to its token count.
    """
    if not self.pcp_size > 1:
        return

    pcp_size = self.pcp_size
    num_chunks = 2 * pcp_size
    # Index space the slices are taken from. A lazy `range` is sliced with
    # the same clipping semantics as a list, so it is built once here
    # instead of materialising a large list for every prefill request.
    index_space = range(self.max_num_tokens * pcp_size * self.dcp_size +
                        pcp_size * self.dcp_size * self.max_num_reqs)

    per_rank_indices = [[] for _ in range(pcp_size)]
    self.cp_kv_recover_idx_for_chunk = per_rank_indices

    for i, req_id in enumerate(self.input_batch.req_ids):
        num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
        is_prefill = (self.input_batch.num_computed_tokens_cpu[i]
                      < self.input_batch.num_prompt_tokens[i])
        if not is_prefill:
            continue

        # Pad to a multiple of 2 * pcp_size, then split into equal chunks.
        num_cp_padded_scheduled_tokens = cdiv(num_scheduled_tokens,
                                              num_chunks) * num_chunks
        chunk_size = num_cp_padded_scheduled_tokens // num_chunks
        # Indices already handed out to earlier requests, over all ranks;
        # sampled before this request's entries are appended.
        offset = len(per_rank_indices[0]) * pcp_size
        mirrored_end = num_cp_padded_scheduled_tokens + offset
        for rank in range(pcp_size):
            # Chunk `rank` counted from the front ...
            per_rank_indices[rank].extend(
                index_space[rank * chunk_size + offset:
                            (rank + 1) * chunk_size + offset])
            # ... and chunk `rank` counted from the back.
            per_rank_indices[rank].extend(
                index_space[mirrored_end - (rank + 1) * chunk_size:
                            mirrored_end - rank * chunk_size])

    # A single host-to-device transfer of the rank-ordered indices; the
    # original copied the very same values a second time.
    recover_idx = torch.from_numpy(
        np.concatenate(per_rank_indices)).to(device=self.device)
    # argsort is run on float32 values and the result stored as int32.
    self.cp_kv_recover_idx_for_chunk = recover_idx.to(
        torch.float32).argsort().to(torch.int32)
|
||
|
||
def _prepare_inputs(
|
||
self,
|
||
scheduler_output: "SchedulerOutput",
|
||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||
) -> tuple[dict[str, Any], torch.Tensor, np.ndarray, int, torch.Tensor,
|
||
int, torch.Tensor, SpecDecodeMetadata, Optional[torch.Tensor],
|
||
Optional[torch.Tensor], Optional[torch.Tensor], int]:
|
||
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||
assert total_num_scheduled_tokens > 0
|
||
num_reqs = self.input_batch.num_reqs
|
||
assert num_reqs > 0
|
||
|
||
# OPTIMIZATION: Start copying the block table first.
|
||
# This way, we can overlap the copy with the following CPU operations.
|
||
self.input_batch.block_table.commit_block_table(num_reqs)
|
||
|
||
# Get the number of scheduled tokens for each request.
|
||
req_ids = self.input_batch.req_ids
|
||
tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
|
||
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
|
||
|
||
req_indices = np.repeat(self.arange_np[:num_reqs],
|
||
num_scheduled_tokens)
|
||
_, arange = self._get_cumsum_and_arange(num_scheduled_tokens)
|
||
positions_np = np.add(
|
||
self.input_batch.num_computed_tokens_cpu[req_indices],
|
||
arange,
|
||
)
|
||
|
||
self.input_batch.block_table.compute_slot_mapping(
|
||
req_indices, positions_np)
|
||
self.input_batch.block_table.commit_slot_mapping(
|
||
total_num_scheduled_tokens)
|
||
|
||
total_num_pcp_pads = 0
|
||
if self.pcp_size > 1:
|
||
if not self.vllm_config.model_config.use_mla:
|
||
self.generate_kv_idx(scheduler_output)
|
||
tokens, position_pcp, pcp_unpad_mask = self._update_tokens_for_pcp(
|
||
tokens)
|
||
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
|
||
total_num_scheduled_tokens = sum(num_scheduled_tokens[:num_reqs])
|
||
total_num_pcp_pads = torch.sum(self.num_pcp_pads).item()
|
||
else:
|
||
position_pcp, pcp_unpad_mask = None, None
|
||
self.num_pcp_pads = self.num_pcp_pads[:num_reqs]
|
||
|
||
max_num_scheduled_tokens = max(tokens)
|
||
if not scheduler_output.scheduled_spec_decode_tokens:
|
||
num_valid_tokens = np.array(tokens, dtype=np.int32)
|
||
else:
|
||
num_valid_tokens = np.array([
|
||
num_tokens -
|
||
len(scheduler_output.scheduled_spec_decode_tokens.get(i, []))
|
||
for num_tokens, i in zip(tokens, req_ids)
|
||
],
|
||
dtype=np.int32)
|
||
|
||
if (self.use_aclgraph and total_num_scheduled_tokens
|
||
<= self.cudagraph_batch_sizes[-1]):
|
||
# Add padding to the batch size.
|
||
num_input_tokens = self.vllm_config.pad_for_cudagraph(
|
||
total_num_scheduled_tokens)
|
||
elif self.use_aclgraph and enable_sp(self.vllm_config):
|
||
# When using aclgraph, if total_num_scheduled_tokens exceeds the maximum graph size,
|
||
# the model will fall back to running its FX graph in eager mode.
|
||
# In this case, when sequence parallelism is enabled, we need to pad tokens to align
|
||
# with tp_size because pad_size cannot be captured by the FX graph
|
||
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
|
||
num_input_tokens = math.ceil(
|
||
total_num_scheduled_tokens / tp_size) * tp_size
|
||
else:
|
||
# Eager mode.
|
||
num_input_tokens = total_num_scheduled_tokens
|
||
|
||
# Get the attention state.
|
||
attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens,
|
||
num_valid_tokens)
|
||
self.attn_state = attn_state # type: ignore
|
||
|
||
# Determine if it's a splitfuse batch
|
||
with_prefill = attn_state not in [
|
||
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
|
||
]
|
||
|
||
self.query_lens = torch.from_numpy(num_scheduled_tokens)
|
||
|
||
# Get info across DP ranks.
|
||
# NOTE: maybe_padded_num_tokens is only used when using TorchAir with DP,
|
||
# Otherwise, it's just max_tokens_across_dp_cpu
|
||
(maybe_padded_num_tokens, num_tokens_across_dp,
|
||
with_prefill) = self._sync_metadata_across_dp(num_input_tokens,
|
||
with_prefill)
|
||
|
||
# TODO: Now that num_input_tokens is basically identical with maybe_padded_num_tokens
|
||
# We should consider removing maybe_padded_num_tokens later
|
||
num_input_tokens = maybe_padded_num_tokens
|
||
|
||
# Hot-Swap lora model
|
||
if self.lora_config:
|
||
self.set_active_loras(self.input_batch, num_scheduled_tokens)
|
||
|
||
# Get request indices.
|
||
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
|
||
req_indices = np.repeat(self.arange_np[:num_reqs],
|
||
num_scheduled_tokens)
|
||
|
||
# cu_num_tokens: [2, 5, 3] -> [2, 7, 10]
|
||
# arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
|
||
cu_num_tokens, arange = self._get_cumsum_and_arange(
|
||
num_scheduled_tokens)
|
||
|
||
if self.pcp_size > 1:
|
||
positions_np = self.positions.np[:total_num_scheduled_tokens]
|
||
np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
|
||
position_pcp[:total_num_scheduled_tokens],
|
||
out=positions_np)
|
||
else:
|
||
self.positions.np[:total_num_scheduled_tokens] = positions_np
|
||
|
||
# Calculate M-RoPE positions.
|
||
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
|
||
if self.uses_mrope:
|
||
self._calc_mrope_positions(scheduler_output)
|
||
|
||
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
|
||
self.mrope_positions.gpu[:, :total_num_scheduled_tokens].copy_(
|
||
self.mrope_positions.cpu[:, :total_num_scheduled_tokens],
|
||
non_blocking=True)
|
||
|
||
# Get token indices.
|
||
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
|
||
# -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
|
||
# where M is the max_model_len.
|
||
token_indices = (positions_np +
|
||
req_indices * self.input_batch.token_ids_cpu.shape[1])
|
||
token_indices_tensor = torch.from_numpy(token_indices)
|
||
# Prepare input_ids.
|
||
# NOTE(woosuk): We use torch.index_select instead of np.take here
|
||
# because torch.index_select is much faster than np.take for large
|
||
# tensors.
|
||
torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
|
||
0,
|
||
token_indices_tensor,
|
||
out=self.input_ids.cpu[:total_num_scheduled_tokens])
|
||
if self.enable_prompt_embeds:
|
||
is_token_ids = self.input_batch.is_token_ids_tensor.flatten()
|
||
torch.index_select(
|
||
is_token_ids,
|
||
0,
|
||
token_indices_tensor,
|
||
out=self.is_token_ids.cpu[:total_num_scheduled_tokens])
|
||
|
||
# Because we did not pre-allocate a massive prompt_embeds CPU tensor on
|
||
# the InputBatch, we need to fill in the prompt embeds into the expected
|
||
# spots in the GpuModelRunner's pre-allocated prompt_embeds tensor.
|
||
if self.input_batch.req_prompt_embeds and (self.is_multimodal_model or
|
||
self.enable_prompt_embeds):
|
||
output_idx = 0
|
||
for req_idx in range(num_reqs):
|
||
num_sched = num_scheduled_tokens[req_idx]
|
||
|
||
# Skip if this request doesn't have embeddings
|
||
if req_idx not in self.input_batch.req_prompt_embeds:
|
||
output_idx += num_sched
|
||
continue
|
||
|
||
# Skip if no tokens scheduled
|
||
if num_sched <= 0:
|
||
output_idx += num_sched
|
||
continue
|
||
|
||
req_embeds = self.input_batch.req_prompt_embeds[req_idx]
|
||
start_pos = self.input_batch.num_computed_tokens_cpu[req_idx]
|
||
|
||
# Skip if trying to read beyond available embeddings
|
||
if start_pos >= req_embeds.shape[0]:
|
||
output_idx += num_sched
|
||
continue
|
||
|
||
# Copy available embeddings
|
||
end_pos = start_pos + num_sched
|
||
actual_end = min(end_pos, req_embeds.shape[0])
|
||
actual_num_sched = actual_end - start_pos
|
||
|
||
if actual_num_sched > 0:
|
||
self.inputs_embeds.cpu[output_idx:output_idx +
|
||
actual_num_sched].copy_(
|
||
req_embeds[start_pos:actual_end]
|
||
)
|
||
|
||
output_idx += num_sched
|
||
|
||
self.query_start_loc.np[0] = 0
|
||
self.query_start_loc.np[1:num_reqs + 1] = cu_num_tokens
|
||
self.query_start_loc.np[num_reqs + 1:].fill(cu_num_tokens[-1])
|
||
self.query_start_loc.copy_to_gpu()
|
||
|
||
self.seq_lens.np[:num_reqs] = (
|
||
self.input_batch.num_computed_tokens_cpu[:num_reqs] +
|
||
num_scheduled_tokens)
|
||
self.seq_lens.copy_to_gpu()
|
||
|
||
self.seq_lens.gpu[num_reqs:].fill_(0)
|
||
|
||
self.query_lens = torch.from_numpy(num_scheduled_tokens)
|
||
|
||
# Copy the tensors to the NPU.
|
||
self._prepare_input_ids(scheduler_output, total_num_scheduled_tokens,
|
||
cu_num_tokens)
|
||
self.positions.cpu[total_num_scheduled_tokens:num_input_tokens].zero_()
|
||
self.positions.copy_to_gpu()
|
||
|
||
attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens,
|
||
num_valid_tokens)
|
||
self.attn_mask = self._make_attention_mask(attn_state)
|
||
self.attn_state = attn_state # type: ignore
|
||
|
||
self.with_prefill = with_prefill
|
||
self.num_tokens_across_dp = num_tokens_across_dp
|
||
attn_metadata: dict[str, Any] = {}
|
||
|
||
# Record the index of requests that should not be sampled,
|
||
# so that we could clear the sampled tokens before returning
|
||
num_tokens = [
|
||
self.requests[r].num_tokens for r in self.input_batch.req_ids
|
||
]
|
||
num_tokens_np = np.array(num_tokens, dtype=np.int32)
|
||
base_num_reqs = self.input_batch.num_reqs
|
||
num_reqs = base_num_reqs
|
||
if self.pcp_size > 1:
|
||
# while pcp > 1, we need the original num_scheduled_tokens before split
|
||
# to calculate discard_requests_mask
|
||
tokens_original = [
|
||
scheduler_output.num_scheduled_tokens[i] for i in req_ids
|
||
]
|
||
original_seq_lens_np = (
|
||
self.input_batch.num_computed_tokens_cpu[:num_reqs] +
|
||
np.array(tokens_original, dtype=np.int32))
|
||
discard_requests_mask = original_seq_lens_np < num_tokens_np
|
||
else:
|
||
discard_requests_mask = self.seq_lens.np[:num_reqs] < num_tokens_np
|
||
|
||
discard_request_indices = np.nonzero(discard_requests_mask)[0]
|
||
self.num_discarded_requests = len(discard_request_indices)
|
||
self.discard_request_indices.np[:self.num_discarded_requests] = (
|
||
discard_request_indices)
|
||
self.discard_request_indices.copy_to_gpu(self.num_discarded_requests)
|
||
|
||
# _prepare_inputs may reorder the batch, so we must gather
|
||
# multi-modal outputs after that to ensure the correct order
|
||
if self.is_multimodal_model:
|
||
with self.maybe_get_ec_connector_output(
|
||
scheduler_output,
|
||
encoder_cache=self.encoder_cache,
|
||
):
|
||
# Run the multimodal encoder if any.
|
||
self._execute_mm_encoder(scheduler_output)
|
||
|
||
# NOTE(woosuk): To unify token ids and soft tokens (vision
|
||
# embeddings), we always use embeddings (rather than token ids)
|
||
# as input to the multimodal model, even when the input is text.
|
||
input_ids = self.input_ids.gpu[:total_num_scheduled_tokens]
|
||
mm_embeds, is_mm_embed = self._gather_mm_embeddings(
|
||
scheduler_output)
|
||
|
||
inputs_embeds = self.model.embed_input_ids(
|
||
input_ids,
|
||
multimodal_embeddings=mm_embeds,
|
||
is_multimodal=is_mm_embed,
|
||
)
|
||
|
||
# TODO(woosuk): Avoid the copy. Optimize.
|
||
self.inputs_embeds.gpu[:total_num_scheduled_tokens].copy_(
|
||
inputs_embeds)
|
||
inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens]
|
||
input_ids = None
|
||
elif self.enable_prompt_embeds and get_pp_group().is_first_rank:
|
||
# Get the input embeddings for the tokens that are not input embeds,
|
||
# then put them into the appropriate positions.
|
||
# TODO(qthequartermasterman): Since even when prompt embeds are
|
||
# enabled, (a) not all requests will use prompt embeds, and (b)
|
||
# after the initial prompt is processed, the rest of the generated
|
||
# tokens will be token ids, it is not desirable to have the
|
||
# embedding layer outside of the acl graph all the time. The v0
|
||
# engine avoids this by "double compiling" the acl graph, once
|
||
# with input_ids and again with inputs_embeds, for all num_tokens.
|
||
# If a batch only has token ids, then including the embedding layer
|
||
# in the acl graph will be more performant (like in the else case
|
||
# below).
|
||
token_ids_idx = self.is_token_ids.gpu[:total_num_scheduled_tokens] \
|
||
.nonzero(as_tuple=False) \
|
||
.squeeze(1)
|
||
# Some tokens ids may need to become embeds
|
||
if token_ids_idx.numel() > 0:
|
||
token_ids = self.input_ids.gpu[token_ids_idx]
|
||
tokens_to_embeds = self.model.embed_input_ids(
|
||
input_ids=token_ids)
|
||
self.inputs_embeds.gpu[token_ids_idx] = tokens_to_embeds
|
||
|
||
inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens]
|
||
input_ids = None
|
||
else:
|
||
# For text-only models, we use token ids as input.
|
||
# While it is possible to use embeddings as input just like the
|
||
# multimodal models, it is not desirable for performance since
|
||
# then the embedding layer is not included in the ACL graph.
|
||
input_ids = self.input_ids.gpu[:num_input_tokens]
|
||
inputs_embeds = None
|
||
positions = self.positions.gpu[:num_input_tokens]
|
||
if self.uses_mrope:
|
||
positions = self.mrope_positions.gpu[:, :num_input_tokens]
|
||
|
||
# type: ignore
|
||
if get_pp_group().is_first_rank:
|
||
intermediate_tensors = None
|
||
else:
|
||
assert intermediate_tensors is not None
|
||
assert self.intermediate_tensors is not None
|
||
# If both flashcomm1 and pp are used simultaneously,
|
||
# the shape of the received data and the shape of the space to be copied to will not match,
|
||
# requiring a recalculation of the incoming data's shape.
|
||
tp_size = get_tensor_model_parallel_world_size()
|
||
num_input_tokens_with_flashcomm1 = num_input_tokens
|
||
if enable_sp():
|
||
num_input_tokens_with_flashcomm1 = (num_input_tokens +
|
||
tp_size - 1) // tp_size
|
||
for k, v in intermediate_tensors.items():
|
||
self.intermediate_tensors[
|
||
k][:num_input_tokens_with_flashcomm1].copy_(
|
||
v[:num_input_tokens_with_flashcomm1],
|
||
non_blocking=True)
|
||
intermediate_tensors = IntermediateTensors({
|
||
k:
|
||
v[:num_input_tokens_with_flashcomm1]
|
||
for k, v in self.intermediate_tensors.items()
|
||
})
|
||
|
||
use_spec_decode = len(
|
||
scheduler_output.scheduled_spec_decode_tokens) > 0
|
||
if not use_spec_decode:
|
||
# NOTE(woosuk): Due to chunked prefills, the batch may contain
|
||
# partial requests. While we should not sample any token
|
||
# from these partial requests, we do so for simplicity.
|
||
# We will ignore the sampled tokens from the partial requests.
|
||
# TODO: Support prompt logprobs.
|
||
spec_decode_metadata = None
|
||
if self.pcp_size * self.dcp_size > 1:
|
||
logits_indices = torch.from_numpy(
|
||
cu_num_tokens
|
||
) * self.pcp_size - self.num_pcp_pads[:num_reqs] - 1
|
||
logits_indices = logits_indices.pin_memory().to(
|
||
self.device, non_blocking=True)
|
||
else:
|
||
logits_indices = self.query_start_loc.gpu[1:num_reqs + 1] - 1
|
||
else:
|
||
# Get the number of draft tokens for each request.
|
||
# Iterate over the dictionary rather than all requests since not all
|
||
# requests have draft tokens.
|
||
num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
|
||
# For chunked prefills, use -1 as mask rather than 0, as guided
|
||
# decoding may rollback speculative tokens.
|
||
num_decode_draft_tokens = np.full(num_reqs, -1, dtype=np.int32)
|
||
for req_id, draft_token_ids in (
|
||
scheduler_output.scheduled_spec_decode_tokens.items()):
|
||
req_idx = self.input_batch.req_id_to_index[req_id]
|
||
num_draft_tokens[req_idx] = len(draft_token_ids)
|
||
num_decode_draft_tokens[req_idx] = (len(draft_token_ids) if (
|
||
self.input_batch.num_computed_tokens_cpu[req_idx]
|
||
>= self.input_batch.num_prompt_tokens[req_idx]) else -1)
|
||
|
||
spec_decode_metadata = self._calc_spec_decode_metadata(
|
||
num_draft_tokens, cu_num_tokens, self.num_pcp_pads[:num_reqs])
|
||
logits_indices = spec_decode_metadata.logits_indices
|
||
|
||
# For DECODE only cuda graph of some attention backends (e.g., GDN).
|
||
self.num_decode_draft_tokens.np[:
|
||
num_reqs] = num_decode_draft_tokens
|
||
self.num_decode_draft_tokens.np[num_reqs:].fill(-1)
|
||
self.num_decode_draft_tokens.copy_to_gpu()
|
||
# save logits_indices for pcp spec decode usage
|
||
self.logits_indices = logits_indices
|
||
|
||
# Used in the below loop.
|
||
# query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1]
|
||
num_computed_tokens_cpu = (
|
||
self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
|
||
self.spec_decode_common_attn_metadata = None
|
||
if use_spec_decode and self.need_accepted_tokens:
|
||
self.num_accepted_tokens.np[:num_reqs] = (
|
||
self.input_batch.num_accepted_tokens_cpu[:num_reqs])
|
||
self.num_accepted_tokens.np[num_reqs:].fill(1)
|
||
self.num_accepted_tokens.copy_to_gpu()
|
||
|
||
if self.speculative_config and self.pcp_size > 1:
|
||
self._generate_pcp_mtp_input(
|
||
num_reqs, scheduler_output.total_num_scheduled_tokens,
|
||
scheduler_output.num_scheduled_tokens)
|
||
|
||
long_seq_metadata = self._generate_pcp_metadata(
|
||
total_num_scheduled_tokens)
|
||
# Prepare the attention metadata for each KV cache group and make layers
|
||
# in the same group share the same metadata.
|
||
for kv_cache_group_id, kv_cache_group_spec in enumerate(
|
||
self.kv_cache_config.kv_cache_groups):
|
||
# NOTE: This is strange, why did we use total_num_scheduled_tokens before?
|
||
slot_mapping_size = (total_num_scheduled_tokens
|
||
if self.pcp_size == 1 else
|
||
total_num_scheduled_tokens * self.pcp_size -
|
||
total_num_pcp_pads)
|
||
if isinstance(kv_cache_group_spec.kv_cache_spec,
|
||
EncoderOnlyAttentionSpec):
|
||
# Encoder-only layers do not have KV cache, so we need to
|
||
# create a dummy block table and slot mapping for them.
|
||
blk_table_tensor = torch.zeros(
|
||
(num_reqs, 1),
|
||
dtype=torch.int32,
|
||
device=self.device,
|
||
)
|
||
slot_mapping = torch.zeros(
|
||
(total_num_scheduled_tokens, ),
|
||
dtype=torch.int64,
|
||
device=self.device,
|
||
)
|
||
else:
|
||
blk_table = self.input_batch.block_table[kv_cache_group_id]
|
||
blk_table_tensor = blk_table.get_device_tensor()
|
||
blk_table.slot_mapping.gpu[slot_mapping_size:].fill_(0)
|
||
if self.pcp_size > 1:
|
||
slot_mapping_for_pcp = blk_table.slot_mapping.gpu[:
|
||
long_seq_metadata
|
||
.
|
||
num_actual_tokens_pcp_padded]
|
||
slot_mapping_for_pcp[slot_mapping_size:].fill_(-1)
|
||
assert pcp_unpad_mask is not None
|
||
pcp_padded_slot_mapping = self.pcp_padded_slot_mapping[:
|
||
pcp_unpad_mask
|
||
.
|
||
shape[
|
||
0]]
|
||
pcp_padded_slot_mapping.fill_(-1)
|
||
pcp_padded_slot_mapping[
|
||
pcp_unpad_mask] = slot_mapping_for_pcp[:
|
||
slot_mapping_size]
|
||
slot_mapping_for_pcp[:long_seq_metadata.
|
||
num_actual_tokens_pcp_padded] = pcp_padded_slot_mapping
|
||
blk_table.slot_mapping.gpu[:long_seq_metadata.num_actual_tokens_pcp_padded] = \
|
||
slot_mapping_for_pcp
|
||
slot_mapping = blk_table.slot_mapping.gpu
|
||
|
||
# NOTE: This is a temporary hack, now in GPUModelRunner, this prepare_inputs
|
||
# has been split to multiple parts, and there are 3 parts that is related to this
|
||
# `num_reqs`, we'll take `query_start_loc` as an example:
|
||
# 1. self.query_start_loc.np[1 : num_reqs + 1] = cu_num_tokens
|
||
# 2. get `num_reqs_padded`, this depends on dispatcher and which is why we have the
|
||
# following simplified `dispatch` logic here, we try to minimize the impact
|
||
# 3. query_start_loc = self.query_start_loc.gpu[: num_reqs_padded + 1]
|
||
uniform_decode = (max_num_scheduled_tokens == self.uniform_decode_query_len) \
|
||
and (total_num_scheduled_tokens == max_num_scheduled_tokens * num_reqs)
|
||
|
||
# TODO: We should make this official ASAP. Also note that if we pad here,
|
||
# the builders won’t need to add any extra padding.
|
||
max_decode_tokens = self.scheduler_config.max_num_seqs * self.uniform_decode_query_len
|
||
if self.compilation_config.cudagraph_mode.decode_mode() == CUDAGraphMode.FULL and \
|
||
uniform_decode and self.uniform_decode_query_len <= num_input_tokens <= max_decode_tokens:
|
||
num_reqs_padded = num_input_tokens // self.uniform_decode_query_len
|
||
pad_size = num_reqs_padded - num_reqs
|
||
if pad_size > 0:
|
||
last_query_loc = self.query_start_loc.np[num_reqs]
|
||
|
||
self.query_start_loc.np[
|
||
num_reqs + 1:num_reqs_padded + 1] = self.arange_np[
|
||
1:pad_size +
|
||
1] * self.uniform_decode_query_len + last_query_loc
|
||
self.query_start_loc.copy_to_gpu(num_reqs_padded + 1)
|
||
|
||
# So we are trying to simulate the behavior of GPUModelRunner's
|
||
# prepare_inputs for uniform decode mode by padding query_start_loc
|
||
num_reqs = num_reqs_padded
|
||
|
||
# Make AscendCommonAttentionMetadata
|
||
common_attn_metadata = AscendCommonAttentionMetadata(
|
||
query_start_loc=self.query_start_loc.gpu[:num_reqs + 1],
|
||
query_start_loc_cpu=self.query_start_loc.cpu[:num_reqs + 1],
|
||
seq_lens_cpu=self.seq_lens.cpu[:num_reqs],
|
||
seq_lens=self.seq_lens.gpu[:num_reqs],
|
||
num_reqs=num_reqs,
|
||
num_actual_tokens=slot_mapping_size,
|
||
num_input_tokens=num_input_tokens,
|
||
actual_seq_lengths_q=self.actual_seq_lengths_q,
|
||
# TODO: change this to the right block table for linear attn
|
||
block_table_tensor=blk_table_tensor[:num_reqs],
|
||
slot_mapping=slot_mapping,
|
||
num_computed_tokens_cpu=num_computed_tokens_cpu,
|
||
positions=self.positions.gpu,
|
||
attn_mask=self.attn_mask,
|
||
spec_attn_mask=self.spec_attn_mask,
|
||
attn_state=self.attn_state,
|
||
max_query_len=max_num_scheduled_tokens,
|
||
decode_token_per_req=self.decode_token_per_req,
|
||
prefill_context_parallel_metadata=long_seq_metadata,
|
||
)
|
||
|
||
if self.speculative_config and self.pcp_size > 1:
|
||
# For pcp + spec decode, we flatten block_table
|
||
# to avoid irregular spec_attn_mask shape, e.g.,
|
||
# num_decode_req=2, num_prefill_req=3, num_speculative_tokens=1,
|
||
# ori block_table: # [d0, d1, p0, p1, p2]
|
||
# (num_reqs_d + num_reqs_p, max_num_blocks),
|
||
# flattened block_table: [d0, d0, d1, d1, p0, p1, p2]
|
||
# (num_reqs_d * decode_threshold + num_reqs_p, max_num_blocks),
|
||
ori_query_lens = self.query_start_loc_pcp_full.cpu[1:num_reqs + 1] - \
|
||
self.query_start_loc_pcp_full.cpu[:num_reqs]
|
||
num_prefill_reqs = (ori_query_lens
|
||
> self.decode_threshold).sum().item()
|
||
num_decode_reqs = num_reqs - num_prefill_reqs
|
||
num_decode_reqs_flatten = num_decode_reqs * self.decode_threshold
|
||
blk_table_tensor[
|
||
num_decode_reqs_flatten:num_decode_reqs_flatten +
|
||
num_prefill_reqs].copy_(
|
||
blk_table_tensor[num_decode_reqs:num_decode_reqs +
|
||
num_prefill_reqs].clone())
|
||
blk_table_tensor[:num_decode_reqs_flatten].copy_(
|
||
blk_table_tensor[:num_decode_reqs].repeat_interleave(
|
||
self.decode_threshold, dim=0))
|
||
common_attn_metadata.block_table_tensor = \
|
||
blk_table_tensor[:num_decode_reqs_flatten + num_prefill_reqs]
|
||
|
||
if self.speculative_config and \
|
||
self.spec_decode_common_attn_metadata is None:
|
||
self.spec_decode_common_attn_metadata = common_attn_metadata
|
||
if self.speculative_config.method in ("eagle", "eagle3") and \
|
||
self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs():
|
||
self.spec_decode_common_attn_metadata = \
|
||
self.spec_decode_common_attn_metadata.unpadded(
|
||
total_num_scheduled_tokens, base_num_reqs)
|
||
|
||
for attn_group in self.attn_groups[kv_cache_group_id]:
|
||
common_prefix_len = 0
|
||
extra_attn_metadata_args = {}
|
||
builder = attn_group.get_metadata_builder()
|
||
if isinstance(builder, GDNAttentionMetadataBuilder):
|
||
if use_spec_decode:
|
||
patch_torch_npu_argsort()
|
||
extra_attn_metadata_args = dict(
|
||
num_accepted_tokens=self.num_accepted_tokens.
|
||
gpu[:num_reqs],
|
||
num_decode_draft_tokens_cpu=self.
|
||
num_decode_draft_tokens.cpu[:num_reqs],
|
||
)
|
||
attn_metadata_i = builder.build(
|
||
common_prefix_len=common_prefix_len,
|
||
common_attn_metadata=common_attn_metadata,
|
||
**extra_attn_metadata_args)
|
||
elif self.model_config.runner_type == "pooling":
|
||
attn_metadata_i = builder.build(
|
||
common_prefix_len=common_prefix_len,
|
||
common_attn_metadata=common_attn_metadata,
|
||
**extra_attn_metadata_args)
|
||
else:
|
||
attn_metadata_i = builder.build(
|
||
common_prefix_len=common_prefix_len,
|
||
common_attn_metadata=common_attn_metadata,
|
||
model=self.get_model(),
|
||
**extra_attn_metadata_args)
|
||
|
||
for layer_name in attn_group.layer_names:
|
||
attn_metadata[layer_name] = attn_metadata_i
|
||
|
||
# update global cos, sin
|
||
update_cos_sin(positions)
|
||
|
||
if lmhead_tp_enable():
|
||
max_num_reqs_across_dp = self.max_num_reqs * self.uniform_decode_query_len
|
||
logits_indices = nn.functional.pad(
|
||
logits_indices,
|
||
(0, max_num_reqs_across_dp - logits_indices.shape[0]))
|
||
|
||
return (attn_metadata, positions, num_scheduled_tokens,
|
||
num_input_tokens, num_tokens_across_dp,
|
||
maybe_padded_num_tokens, logits_indices, spec_decode_metadata,
|
||
input_ids, inputs_embeds, intermediate_tensors,
|
||
max_num_scheduled_tokens)
|
||
|
||
def _generate_process_reqs_hidden_states(self, maybe_padded_num_tokens,
                                         input_ids, positions,
                                         intermediate_tensors,
                                         inputs_embeds):
    """Run the model forward pass and post-process its output.

    The model is invoked with the given inputs plus the keyword arguments
    produced by ``self._init_model_kwargs(maybe_padded_num_tokens)``. When the
    forward context reports ``CUDAGraphMode.FULL`` (and ``self.use_sparse`` is
    falsy) the matching ``update_*_attn*_params`` helper is called afterwards.
    The output is then all-gathered over the tensor-parallel group when the
    forward context has ``sp_enabled`` set, and over the PCP group when
    ``self.pcp_size > 1``.

    Args:
        maybe_padded_num_tokens: Token count forwarded to
            ``_init_model_kwargs`` and to the attention-parameter updaters.
        input_ids: Passed to the model as ``input_ids``.
        positions: Passed to the model as ``positions``.
        intermediate_tensors: Passed to the model as ``intermediate_tensors``.
        inputs_embeds: Passed to the model as ``inputs_embeds``.

    Returns:
        The model output after the optional gather / un-pad / reorder steps.
    """
    assert self.model is not None
    extra_model_kwargs = self._init_model_kwargs(maybe_padded_num_tokens)
    hidden_states = self.model(
        input_ids=input_ids,
        positions=positions,
        intermediate_tensors=intermediate_tensors,
        inputs_embeds=inputs_embeds,
        **extra_model_kwargs,
    )

    forward_context = get_forward_context()
    runs_full_graph = (
        forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL)
    if runs_full_graph and not self.use_sparse:
        # TODO: maybe_padded_num_tokens will be removed, use num_input_tokens instead
        uses_mla = self.vllm_config.model_config.use_mla
        context_parallel = self.pcp_size * self.dcp_size > 1
        if uses_mla and context_parallel:
            # FIXME: Try using `auto_dispatch_capture=True`
            update_mla_attn_dcp_pcp_params(self.update_stream,
                                           forward_context,
                                           maybe_padded_num_tokens)
        elif uses_mla:
            # FIXME: Try using `auto_dispatch_capture=True`
            update_mla_attn_params(self.update_stream, forward_context,
                                   maybe_padded_num_tokens,
                                   self.speculative_config)
        elif context_parallel:
            update_attn_dcp_pcp_params(self.update_stream, forward_context,
                                       maybe_padded_num_tokens)
        else:
            update_attn_params(self.update_stream, forward_context,
                               maybe_padded_num_tokens, self.vllm_config)

    # IntermediateTensors outputs are left untouched by the TP gather.
    is_plain_output = not isinstance(hidden_states, IntermediateTensors)
    if get_forward_context().sp_enabled and is_plain_output:
        hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
        pad_size = get_forward_context().pad_size
        if pad_size > 0:
            # Drop the trailing rows indicated by the context's pad_size.
            hidden_states = hidden_states[:-pad_size, :]

    if self.pcp_size > 1:
        tokens_per_pcp_rank = (self.num_actual_tokens_pcp_padded //
                               self.pcp_size)
        gathered = get_pcp_group().all_gather(
            hidden_states[:tokens_per_pcp_rank], 0)
        # Reorder the gathered rows with the precomputed restore index.
        restore_idx = self.pcp_allgather_restore_idx[:gathered.shape[0]]
        hidden_states = torch.index_select(gathered, 0, restore_idx)
    return hidden_states
|
||
|
||
def _build_attn_state(self, num_reqs, num_scheduled_tokens,
                      num_valid_tokens):
    """Pick the ``AscendAttentionState`` for the current batch.

    The checks run in order and the first match wins:

    1. Pooling runner: ``PrefillNoCache`` when the first KV cache group is an
       ``EncoderOnlyAttentionSpec``, otherwise ``PrefillCacheHit``.
    2. ``self.seq_lens.np[:num_reqs]`` equals ``num_scheduled_tokens``:
       ``PrefillNoCache``.
    3. Every entry of ``num_scheduled_tokens`` is 1: ``DecodeOnly``, or
       ``SpecDecoding`` when the speculative method is ``'mtp'``.
    4. Every entry of ``num_valid_tokens`` is 1: ``SpecDecoding`` for
       ``'mtp'``, otherwise ``ChunkedPrefill``.
    5. Chunked prefill enabled in the scheduler config: ``ChunkedPrefill``.
    6. Fallback: ``PrefillCacheHit``.

    Args:
        num_reqs: Number of leading entries of ``self.seq_lens.np`` to compare.
        num_scheduled_tokens: Array compared against the sequence lengths and
            against 1.
        num_valid_tokens: Array compared element-wise against 1.

    Returns:
        The selected ``AscendAttentionState`` member.
    """

    def uses_mtp():
        # Evaluated lazily so the config is only consulted in the branches
        # that need it.
        spec_config = self.speculative_config
        return spec_config and spec_config.method == 'mtp'

    if self.model_config.runner_type == "pooling":
        first_group_spec = (
            self.kv_cache_config.kv_cache_groups[0].kv_cache_spec)
        if isinstance(first_group_spec, EncoderOnlyAttentionSpec):
            return AscendAttentionState.PrefillNoCache
        return AscendAttentionState.PrefillCacheHit

    if np.array_equal(self.seq_lens.np[:num_reqs], num_scheduled_tokens):
        return AscendAttentionState.PrefillNoCache

    # We assume it is the decode stage, where prefill occurs but only one
    # token is not hit in cache.
    if np.all(num_scheduled_tokens == 1):
        if uses_mtp():
            # SpecDecoding now supports seq_len=1 and seq_len=2.
            # In the prefill/decode disaggregation scenario, SpecDecoding
            # needs to support seq_len=1.
            return AscendAttentionState.SpecDecoding
        return AscendAttentionState.DecodeOnly

    # Speculative decoding.
    if np.all(num_valid_tokens == 1):
        if uses_mtp():
            return AscendAttentionState.SpecDecoding
        return AscendAttentionState.ChunkedPrefill

    # splitfuse
    if self.scheduler_config.enable_chunked_prefill:
        return AscendAttentionState.ChunkedPrefill
    return AscendAttentionState.PrefillCacheHit
|
||
|
||
def _calc_spec_decode_metadata(
    self,
    num_draft_tokens: np.ndarray,
    cu_num_scheduled_tokens: np.ndarray,
    num_pcp_pads: np.ndarray,
) -> SpecDecodeMetadata:
    """Build the index tensors needed to sample with speculative decoding.

    Worked example (one entry per request):

        Inputs:
          cu_num_scheduled_tokens:  [  4, 104, 107, 207, 209]
          num_draft_tokens:         [  3,   0,   2,   0,   1]
        Outputs:
          cu_num_draft_tokens:      [  3,   3,   5,   5,   6]
          logits_indices:           [  0,   1,   2,   3, 103, 104, 105, 106,
                                     206, 207, 208]
          target_logits_indices:    [  0,   1,   2,   5,   6,   9]
          bonus_logits_indices:     [  3,   4,   7,   8,  10]

    Args:
        num_draft_tokens: Number of draft tokens per request.
        cu_num_scheduled_tokens: Cumulative scheduled-token counts.
        num_pcp_pads: Per-request pad counts; only read when
            ``self.pcp_size > 1``.

    Returns:
        A ``SpecDecodeMetadata`` holding the draft token ids and the index
        tensors moved to ``self.device``.
    """

    def to_device(host_array: np.ndarray) -> torch.Tensor:
        # TODO: Optimize the CPU -> NPU copy.
        return torch.from_numpy(host_array).pin_memory().to(
            self.device, non_blocking=True)

    # Compute the logits indices.
    # [4, 1, 3, 1, 2]
    num_sampled_tokens = num_draft_tokens + 1
    # Step 1. [4, 5, 8, 9, 11]
    cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32)
    total_sampled = cu_num_sampled_tokens[-1]
    # Step 2. [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
    sampled_starts = cu_num_sampled_tokens - num_sampled_tokens
    sampled_base = np.repeat(sampled_starts, num_sampled_tokens)
    # Step 3. [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
    sampled_offsets = self.arange_np[:total_sampled] - sampled_base
    # Step 4. [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]
    logits_indices = np.repeat(cu_num_scheduled_tokens - num_sampled_tokens,
                               num_sampled_tokens)
    # Step 5. [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
    logits_indices += sampled_offsets

    # While pcp > 1, decode results may contain padding (from the pcp
    # all-gather); the logits indices are swapped for the pcp-adjusted ones
    # only after draft_token_ids has been read with the original indices.
    logits_indices_pcp = None
    if self.pcp_size > 1:
        cu_num_scheduled_tokens = (
            cu_num_scheduled_tokens * self.pcp_size - num_pcp_pads)
        logits_indices_pcp = np.repeat(
            cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens)
        logits_indices_pcp += sampled_offsets
        logits_indices_pcp = to_device(logits_indices_pcp)

    # Compute the bonus logits indices.
    bonus_logits_indices = cu_num_sampled_tokens - 1

    # Compute the draft logits indices.
    # [3, 3, 5, 5, 6]
    cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32)
    total_drafts = cu_num_draft_tokens[-1]
    # [0, 0, 0, 3, 3, 5]
    draft_base = np.repeat(cu_num_draft_tokens - num_draft_tokens,
                           num_draft_tokens)
    # [0, 1, 2, 0, 1, 0]
    draft_offsets = self.arange_np[:total_drafts] - draft_base
    # [0, 0, 0, 5, 5, 9]
    target_logits_indices = np.repeat(sampled_starts, num_draft_tokens)
    # [0, 1, 2, 5, 6, 9]
    target_logits_indices += draft_offsets

    # Move everything to the device, in the same order as before.
    cu_num_draft_tokens = to_device(cu_num_draft_tokens)
    cu_num_sampled_tokens = to_device(cu_num_sampled_tokens)
    logits_indices = to_device(logits_indices)
    target_logits_indices = to_device(target_logits_indices)
    bonus_logits_indices = to_device(bonus_logits_indices)

    # Compute the draft token ids.
    # draft_token_indices:      [  1,   2,   3, 105, 106, 208]
    sampled_input_ids = self.input_ids.gpu[logits_indices]
    draft_token_ids = sampled_input_ids[target_logits_indices + 1]

    final_logits_indices = (logits_indices_pcp
                            if self.pcp_size > 1 else logits_indices)
    return SpecDecodeMetadata(
        draft_token_ids=draft_token_ids,
        num_draft_tokens=num_draft_tokens.tolist(),
        cu_num_draft_tokens=cu_num_draft_tokens,
        cu_num_sampled_tokens=cu_num_sampled_tokens,
        target_logits_indices=target_logits_indices,
        bonus_logits_indices=bonus_logits_indices,
        logits_indices=final_logits_indices,
    )
|
||
|
||
def propose_draft_token_ids(
    self,
    valid_sampled_token_ids: torch.Tensor | list[list[int]],
    sampling_metadata: SamplingMetadata,
    scheduler_output: "SchedulerOutput",
    spec_decode_metadata: SpecDecodeMetadata,
    positions: torch.Tensor,
    num_scheduled_tokens: int,
    hidden_states: torch.Tensor,
    attn_metadata: dict[str, Any],
    aux_hidden_states: torch.Tensor = None,
) -> Optional[list[list[int]]]:
    """Ask the configured drafter for draft token ids.

    All arguments except ``attn_metadata`` are forwarded positionally to
    ``self.drafter.generate_token_ids``; ``attn_metadata`` is accepted but
    not passed on.

    Returns:
        The drafter's result, or ``None`` when ``self.drafter`` is falsy
        (speculative decoding is not enabled).
    """
    drafter = self.drafter
    if not drafter:
        # Speculative decoding is not enabled.
        return None
    return drafter.generate_token_ids(
        valid_sampled_token_ids,
        sampling_metadata,
        scheduler_output,
        spec_decode_metadata,
        positions,
        num_scheduled_tokens,
        hidden_states,
        aux_hidden_states,
    )
|
||
|
||
@staticmethod
def get_finished_kv_transfer(
    scheduler_output: "SchedulerOutput",
) -> tuple[Optional[set[str]], Optional[set[str]]]:
    """Query the KV transfer group for finished transfers.

    Args:
        scheduler_output: Its ``finished_req_ids`` is handed to the KV
            transfer group's ``get_finished``.

    Returns:
        Whatever ``get_finished`` returns when a KV transfer group exists,
        otherwise ``(None, None)``.
    """
    if not has_kv_transfer_group():
        return None, None
    transfer_group = get_kv_transfer_group()
    return transfer_group.get_finished(scheduler_output.finished_req_ids)
|
||
|
||
@torch.inference_mode()
def execute_model(
    self,
    scheduler_output: "SchedulerOutput",
    intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[ModelRunnerOutput, IntermediateTensors] | None:
    """Prepare inputs, run the forward pass and stash state for sampling.

    On the regular path the method stores an ``ExecuteModelState`` in
    ``self.execute_model_state`` and returns ``None``; ``sample_tokens()``
    consumes that state afterwards.

    Early returns:
      * EC-transfer producer: runs the multimodal encoder only and returns
        ``make_empty_encoder_model_runner_output(scheduler_output)``.
      * No scheduled tokens: returns ``EMPTY_MODEL_RUNNER_OUTPUT`` when no KV
        transfer group exists, otherwise the result of
        ``kv_connector_no_forward``.
      * Non-last pipeline rank without PP broadcast: returns the hidden
        states (with ``kv_connector_output`` attached).
      * Pooling requests on the last rank: returns the output of ``_pool``.

    Args:
        scheduler_output: The scheduler decision for this step.
        intermediate_tensors: Forwarded to ``_prepare_inputs``.

    Raises:
        RuntimeError: If the state from a previous call has not yet been
            consumed by ``sample_tokens()``.
    """
    if self.execute_model_state is not None:
        raise RuntimeError("State error: sample_tokens() must be called "
                           "after execute_model() returns None.")

    with ProfileExecuteDuration().capture_async("prepare input"):
        self._update_states(scheduler_output)
        if has_ec_transfer() and get_ec_transfer().is_producer:
            with self.maybe_get_ec_connector_output(
                    scheduler_output,
                    encoder_cache=self.encoder_cache,
            ):
                self._execute_mm_encoder(scheduler_output)
                return make_empty_encoder_model_runner_output(
                    scheduler_output)

        if not scheduler_output.total_num_scheduled_tokens:
            if not has_kv_transfer_group():
                logger.debug(
                    "skip this step for we receive the data from remote disaggregate prefill node"
                )
                # Return empty ModelRunnerOutput if there's no work to do.
                return EMPTY_MODEL_RUNNER_OUTPUT
            return self.kv_connector_no_forward(scheduler_output,
                                                self.vllm_config)

        if self.dynamic_eplb:
            self.eplb_updator.forward_before()

        (attn_metadata, positions, num_scheduled_tokens_np,
         num_input_tokens, num_tokens_across_dp, maybe_padded_num_tokens,
         logits_indices, spec_decode_metadata, input_ids, inputs_embeds,
         intermediate_tensors,
         max_query_len) = (self._prepare_inputs(scheduler_output,
                                                intermediate_tensors))

        if self.dynamic_eplb:
            self.eplb_updator.take_update_info_from_eplb_process()

    # Only dump when dumping is enabled AND a debugger object exists.
    need_dump = self.dump_enable and self.debugger is not None
    if need_dump:
        assert self.debugger is not None
        dbg_cfg = getattr(self.debugger, "config", None)
        dump_level = str(
            getattr(dbg_cfg, "level",
                    "L1")).upper() if dbg_cfg is not None else "L1"
        if dump_level in ("L0", "MIX"):
            self.debugger.start(model=self.model)
        else:
            self.debugger.start()

    uniform_decode = (max_query_len == self.uniform_decode_query_len) and (
        scheduler_output.total_num_scheduled_tokens
        == self.input_batch.num_reqs * max_query_len)
    has_lora = len(self.input_batch.lora_id_to_lora_request) > 0
    aclgraph_runtime_mode, batch_descriptor = \
        self.cudagraph_dispatcher.dispatch(num_tokens=num_input_tokens,
                                           uniform_decode=uniform_decode,
                                           has_lora=has_lora)

    # Run forward pass
    with ProfileExecuteDuration().capture_async("forward"):
        with set_ascend_forward_context(
                attn_metadata,
                self.vllm_config,
                num_tokens=num_input_tokens,
                num_tokens_across_dp=num_tokens_across_dp,
                with_prefill=self.with_prefill,
                aclgraph_runtime_mode=aclgraph_runtime_mode,
                batch_descriptor=batch_descriptor,
                num_actual_tokens=scheduler_output.
                total_num_scheduled_tokens,
                prefetch_stream=self.prefetch_stream,
                model_instance=self.model,
                weight_prefetch_method=self.weight_prefetch_method):
            self.maybe_setup_kv_connector(scheduler_output)

            hidden_states = self._generate_process_reqs_hidden_states(
                maybe_padded_num_tokens, input_ids, positions,
                intermediate_tensors, inputs_embeds)

        self.maybe_wait_for_kv_save()
        finished_sending, finished_recving = self.get_finished_kv_transfer(
            scheduler_output)

        aux_hidden_states = None
        if self.drafter and self.drafter.name == SpecDcodeType.EAGLE3:
            # In this case the forward result is unpacked as a pair.
            hidden_states, aux_hidden_states = hidden_states

    kv_connector_output = KVConnectorOutput(
        finished_sending=finished_sending,
        finished_recving=finished_recving)
    finished_sending = None
    finished_recving = None
    with ProfileExecuteDuration().capture_async("post process"):
        # Broadcast PP output for external_launcher (torchrun)
        # to make sure we are synced across pp ranks
        # TODO: Support overlapping micro-batches
        # https://github.com/vllm-project/vllm/issues/18019
        broadcast_pp_output = \
            self.parallel_config.distributed_executor_backend \
            == "external_launcher" and len(get_pp_group().ranks) > 0
        if not get_pp_group().is_last_rank:
            # For mid-pipeline stages, return the hidden states.
            if not broadcast_pp_output:
                hidden_states.kv_connector_output = kv_connector_output
                self.kv_connector_output = kv_connector_output
                if need_dump:
                    assert self.debugger is not None
                    self.debugger.stop()
                    self.debugger.step()
                return hidden_states
            assert isinstance(hidden_states, IntermediateTensors)
            get_pp_group().send_tensor_dict(
                hidden_states.tensors, all_gather_group=get_tp_group())
            logits = None
            # This rank computes no logits; bind the name explicitly so the
            # ExecuteModelState construction below does not raise
            # UnboundLocalError on the broadcast path.
            sample_hidden_states = None
        else:
            if self.input_batch.pooling_params:
                pool_output = self._pool(
                    hidden_states,
                    scheduler_output.total_num_scheduled_tokens,
                    num_scheduled_tokens_np)
                if need_dump:
                    assert self.debugger is not None
                    self.debugger.stop()
                    self.debugger.step()
                return pool_output
            # Sometimes, after the model is compiled through the AOT backend,
            # the model output may become a list containing only one Tensor object.
            if isinstance(hidden_states, list) and \
                    len(hidden_states) == 1 and \
                    isinstance(hidden_states[0], torch.Tensor):
                hidden_states = hidden_states[0]
            sample_hidden_states = hidden_states[logits_indices]
            logits = self.model.compute_logits(sample_hidden_states)
        if broadcast_pp_output:
            model_output_broadcast_data = {
                "logits": logits.contiguous(),
            } if logits is not None else {}
            model_output_broadcast_data = get_pp_group(
            ).broadcast_tensor_dict(model_output_broadcast_data,
                                    src=len(get_pp_group().ranks) - 1)
            assert model_output_broadcast_data is not None
            logits = model_output_broadcast_data["logits"]

        # Stash the per-step state; sample_tokens() unpacks and clears it.
        self.execute_model_state = ExecuteModelState(
            scheduler_output,
            logits,
            spec_decode_metadata,
            hidden_states,
            sample_hidden_states,
            aux_hidden_states,
            attn_metadata,
            positions,
        )
    self.kv_connector_output = kv_connector_output
    return None
|
||
|
||
@torch.inference_mode
def sample_tokens(
    self, grammar_output: "GrammarOutput | None"
) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors:
    """Sample tokens from the state stashed in ``self.execute_model_state``.

    Both ``self.kv_connector_output`` and ``self.execute_model_state`` are
    read and then reset to ``None`` by this call.

    Args:
        grammar_output: Structured-output data. When it is not ``None`` the
            grammar bitmask is applied to the logits (on CPU, in float)
            before sampling.

    Returns:
        * ``None`` when there is no stashed state and no KV connector
          output;
        * ``EMPTY_MODEL_RUNNER_OUTPUT`` (or a shallow copy carrying the KV
          connector output) when there is no stashed state but a KV
          connector output exists;
        * otherwise a ``ModelRunnerOutput``, wrapped in an
          ``AsyncGPUModelRunnerOutput`` when ``self.use_async_scheduling``
          is set.
    """
    # Take the connector output left behind by the forward pass.
    kv_connector_output = self.kv_connector_output
    self.kv_connector_output = None

    if self.execute_model_state is None:
        # Nothing to do (PP non-final rank case), output isn't used.
        if not kv_connector_output:
            return None  # noqa
        # In case of PP with kv transfer, the kv_connector_output still has
        # to be passed through.
        if kv_connector_output.is_empty():
            return EMPTY_MODEL_RUNNER_OUTPUT
        passthrough = copy(EMPTY_MODEL_RUNNER_OUTPUT)
        passthrough.kv_connector_output = kv_connector_output
        return passthrough

    need_dump = self.dump_enable and self.debugger is not None

    # Unpack the ephemeral state, then clear it.
    (
        scheduler_output,
        logits,
        spec_decode_metadata,
        hidden_states,
        sample_hidden_states,
        aux_hidden_states,
        attn_metadata,
        positions,
    ) = self.execute_model_state
    self.execute_model_state = None

    # Apply structured output bitmasks if present.
    if grammar_output is not None:
        # Unlike gpu_model_runner, the bitmask is applied on CPU:
        # apply_grammar_bitmask relies on torch.compile, which Ascend does
        # not support yet. The logits are moved back and cast to their
        # original dtype afterwards.
        original_dtype = logits.dtype
        cpu_logits = logits.to("cpu").float()
        apply_grammar_bitmask(scheduler_output, grammar_output,
                              self.input_batch, cpu_logits)
        logits = cpu_logits.to(self.device).to(original_dtype)

    with ProfileExecuteDuration().capture_async("Sample"):
        sampler_output = self._sample(logits, spec_decode_metadata)

    def propose_draft_token_ids(sampled_token_ids):
        # Closure over the unpacked state; stores the proposal on self.
        assert self.spec_decode_common_attn_metadata is not None
        self._draft_token_ids = self.propose_draft_token_ids(
            sampled_token_ids,
            self.input_batch.sampling_metadata,
            scheduler_output,
            spec_decode_metadata,
            positions,
            scheduler_output.total_num_scheduled_tokens,
            hidden_states,
            attn_metadata,
            aux_hidden_states,
        )

    bookkeeping = self._bookkeeping_sync(
        scheduler_output,
        sampler_output,
        logits,
        hidden_states,
        scheduler_output.total_num_scheduled_tokens,
        spec_decode_metadata,
    )
    (
        logprobs_lists,
        valid_sampled_token_ids,
        prompt_logprobs_dict,
        req_ids_output_copy,
        req_id_to_index_output_copy,
        invalid_req_indices,
    ) = bookkeeping

    with ProfileExecuteDuration().capture_async("Draft"):
        if self.speculative_config:
            spec_config = self.speculative_config
            if (spec_config.use_eagle()
                    and not spec_config.disable_padded_drafter_batch):
                # EAGLE speculative decoding can use the device-side
                # sampled tokens as inputs.
                propose_draft_token_ids(sampler_output.sampled_token_ids)
            else:
                # ngram and other speculative decoding methods use the
                # sampled tokens on the CPU, so they are run after
                # bookkeeping.
                propose_draft_token_ids(valid_sampled_token_ids)

    if has_kv_transfer_group():
        get_kv_transfer_group().clear_connector_metadata()

    model_runner_output = ModelRunnerOutput(
        req_ids=req_ids_output_copy,
        req_id_to_index=req_id_to_index_output_copy,
        sampled_token_ids=valid_sampled_token_ids,
        logprobs=logprobs_lists,
        prompt_logprobs_dict=prompt_logprobs_dict,
        pooler_output=[],
        kv_connector_output=kv_connector_output,
    )

    durations = ProfileExecuteDuration().pop_captured_sync()
    if durations:
        formatted = " ".join(f"[{tag}]:{duration:.2f}ms"
                             for tag, duration in durations.items())
        if self.attn_state == AscendAttentionState.DecodeOnly:
            captured_name = "Decode"
        else:
            captured_name = "Prefill"
        logger.info("Profile execute duration [%s]:%s", captured_name,
                    formatted)

    if self.dynamic_eplb:
        self.eplb_updator.forward_end()

    # Both return paths close the debugger step before returning.
    if need_dump:
        assert self.debugger is not None
        self.debugger.stop()
        self.debugger.step()

    if not self.use_async_scheduling:
        return model_runner_output

    return AsyncGPUModelRunnerOutput(
        model_runner_output=model_runner_output,
        sampled_token_ids=sampler_output.sampled_token_ids,
        logprobs_tensors=sampler_output.logprobs_tensors,
        invalid_req_indices=invalid_req_indices,
        async_output_copy_stream=self.async_output_copy_stream,
        vocab_size=self.input_batch.vocab_size,
    )
|
||
|
||
# overwrite _sample for lmhead_tp_enable and need_accepted_tokens
|
||
def _sample(self, logits, spec_decode_metadata):
    """Sample the next token(s), with or without speculative decoding.

    Args:
        logits: Logits to sample from; may be ``None``.
        spec_decode_metadata: Speculative-decoding metadata. ``None``
            selects the regular sampler, anything else the rejection
            sampler.

    Returns:
        The output of ``self.sampler`` or ``self.rejection_sampler``.
    """
    sampling_metadata = self.input_batch.sampling_metadata
    use_rejection_sampler = spec_decode_metadata is not None

    # With lm-head TP enabled the logits are truncated to the rows that are
    # actually sampled: one per request normally, one per logits index when
    # speculative decoding is active.
    if lmhead_tp_enable() and logits is not None:
        if use_rejection_sampler:
            num_rows = len(spec_decode_metadata.logits_indices)
        else:
            num_rows = self.input_batch.num_reqs
        logits = logits[:num_rows]

    if not use_rejection_sampler:
        return self.sampler(
            logits=logits,
            sampling_metadata=sampling_metadata,
        )

    sampler_output = self.rejection_sampler(
        spec_decode_metadata,
        None,  # draft_probs
        logits,
        sampling_metadata,
    )
    if self.need_accepted_tokens:  # TODO remove this if
        self._update_states_after_model_execute(
            sampler_output.sampled_token_ids)
    return sampler_output
|
||
|
||
# TODO: remove this func after eagle_proposer is refactored and
|
||
# _bookkeeping_sync is moved after propose_draft_token_ids
|
||
def _bookkeeping_sync(
    self,
    scheduler_output: "SchedulerOutput",
    sampler_output: SamplerOutput,
    logits: torch.Tensor | None,
    hidden_states: torch.Tensor,
    num_scheduled_tokens: int,
    spec_decode_metadata: SpecDecodeMetadata | None,
) -> tuple[
        LogprobsLists | None,
        list[list[int]],
        dict[str, LogprobsTensors | None],
        list[str],
        dict[str, int],
        list[int],
]:
    """Record the sampled tokens in the input batch and request states.

    Args:
        scheduler_output: Scheduler output of the current step.
        sampler_output: Sampler result; ``sampled_token_ids`` and
            ``logprobs_tensors`` are read from it.
        logits: Not read by this method.
        hidden_states: Hidden states; the first ``num_scheduled_tokens``
            rows are used to compute prompt logprobs.
        num_scheduled_tokens: Number of scheduled tokens in this step.
        spec_decode_metadata: Not read by this method.

    Returns:
        ``(logprobs_lists, valid_sampled_token_ids, prompt_logprobs_dict,
        req_ids_copy, req_id_to_index_copy, invalid_req_indices)``.
        With async scheduling ``valid_sampled_token_ids`` is an empty list
        and ``logprobs_lists`` is ``None``.
    """
    # TODO: implement PR 28597 from vllm
    discarded_req_indices = \
        self.discard_request_indices.np[:self.num_discarded_requests]
    for discarded in discarded_req_indices:
        generator = self.input_batch.generators.get(int(discarded))
        if generator is not None:
            # Rewind the generator of a request whose sample is discarded.
            # NOTE(review): the constant 4 presumably matches the offset one
            # sampling step consumes -- confirm against the sampler.
            generator.set_offset(generator.get_offset() - 4)

    # Copy some objects so they don't get modified after returning.
    # This is important when using async scheduling.
    req_ids_output_copy = self.input_batch.req_ids.copy()
    req_id_to_index_output_copy = self.input_batch.req_id_to_index.copy()

    sampled_token_ids = sampler_output.sampled_token_ids
    num_sampled_tokens = sampled_token_ids.shape[0]
    logprobs_tensors = sampler_output.logprobs_tensors
    invalid_req_indices = []
    cu_num_tokens: list[int] | None = None

    if self.use_async_scheduling:
        valid_sampled_token_ids = []
        invalid_req_indices = discarded_req_indices.tolist()
        invalid_req_indices_set = set(invalid_req_indices)

        if self.num_spec_tokens <= 0:
            assert sampled_token_ids.shape[-1] == 1
            # Cache the sampled tokens on the NPU and avoid CPU sync.
            # These will be copied into input_ids in the next step
            # when preparing inputs.
            self.input_batch.prev_sampled_token_ids = sampled_token_ids

        self.input_batch.prev_req_id_to_index = {
            req_id: index
            for index, req_id in enumerate(self.input_batch.req_ids)
            if index not in invalid_req_indices_set
        }
    elif sampled_token_ids.shape[-1] == 1:
        # No spec decode tokens: one sampled token per request.
        valid_sampled_token_ids = self._to_list(sampled_token_ids)
        # Mask out the sampled tokens that should not be sampled.
        for discarded in discarded_req_indices:
            valid_sampled_token_ids[int(discarded)].clear()
    else:
        # Includes spec decode tokens.
        valid_sampled_token_ids, cu_num_tokens = RejectionSampler.parse_output(
            sampled_token_ids,
            self.input_batch.vocab_size,
            discarded_req_indices,
            return_cu_num_tokens=logprobs_tensors is not None,
        )

    # Cache the sampled tokens in the model runner, so that the scheduler
    # doesn't need to send them back.
    # NOTE(woosuk): As an exception, when using PP, the scheduler sends
    # the sampled tokens back, because there's no direct communication
    # between the first-stage worker and the last-stage worker.
    input_batch = self.input_batch
    req_ids = input_batch.req_ids
    for req_idx in range(num_sampled_tokens):
        if self.use_async_scheduling:
            # -1 is a placeholder: the real token is not available on the
            # CPU at this point.
            if req_idx in invalid_req_indices_set:
                sampled_ids = None
            else:
                sampled_ids = [-1]
        else:
            sampled_ids = valid_sampled_token_ids[req_idx]

        if not sampled_ids:
            continue

        start_idx = input_batch.num_tokens_no_spec[req_idx]
        end_idx = start_idx + len(sampled_ids)
        assert end_idx <= self.max_model_len, (
            "Sampled token IDs exceed the max model length. "
            f"Total number of tokens: {end_idx} > max_model_len: "
            f"{self.max_model_len}")

        input_batch.token_ids_cpu[req_idx, start_idx:end_idx] = sampled_ids
        input_batch.is_token_ids[req_idx, start_idx:end_idx] = True
        input_batch.num_tokens_no_spec[req_idx] = end_idx
        input_batch.num_tokens[req_idx] = end_idx

        self.requests[req_ids[req_idx]].output_token_ids.extend(sampled_ids)

    if not self.use_async_scheduling and logprobs_tensors is not None:
        logprobs_lists = logprobs_tensors.tolists(cu_num_tokens)
    else:
        logprobs_lists = None

    # Compute prompt logprobs if needed.
    prompt_logprobs_dict = self._get_prompt_logprobs_dict(
        hidden_states[:num_scheduled_tokens],
        scheduler_output.num_scheduled_tokens,
    )

    return (
        logprobs_lists,
        valid_sampled_token_ids,
        prompt_logprobs_dict,
        req_ids_output_copy,
        req_id_to_index_output_copy,
        invalid_req_indices,
    )
|
||
|
||
def _build_dummy_attn_metadata(
    self,
    with_prefill: bool,
    num_reqs: int,
    num_tokens: int,
    max_query_len: int,
    num_scheduled_tokens: np.ndarray,
    aclgraph_runtime_mode: Optional[CUDAGraphMode] = None,
    force_attention: bool = False,
) -> Optional[dict[str, Any]]:
    """Build per-layer attention metadata for a dummy run.

    Metadata is only built when ``force_attention`` is set or the runtime
    mode is ``CUDAGraphMode.FULL``; otherwise ``None`` is returned.

    Side effects (only when metadata is built): overwrites
    ``self.seq_lens``, ``self.query_start_loc``, ``self.query_lens``,
    ``self.attn_mask`` and ``self.cp_kv_recover_idx``.

    Args:
        with_prefill: Must be ``False`` when metadata is built.
        num_reqs: Number of (padded) dummy requests.
        num_tokens: Number of (padded) dummy tokens.
        max_query_len: Query length, also used as every request's sequence
            length.
        num_scheduled_tokens: Per-request scheduled token counts.
        aclgraph_runtime_mode: Graph runtime mode of the dummy run.
        force_attention: Build metadata regardless of the runtime mode.

    Returns:
        A mapping from layer name to attention metadata, or ``None``.
    """
    if not (force_attention
            or aclgraph_runtime_mode == CUDAGraphMode.FULL):
        return None

    assert with_prefill is False, \
        "Full decode graph only supports uniform batch now."

    attn_metadata: Optional[dict[str, Any]] = {}

    # Every dummy request gets the same sequence length; the rest is zeroed.
    seq_lens = max_query_len
    self.seq_lens.np[:num_reqs] = seq_lens
    self.seq_lens.np[num_reqs:] = 0
    self.seq_lens.copy_to_gpu()

    cu_num_tokens, _ = self._get_cumsum_and_arange(num_scheduled_tokens)

    self.query_start_loc.cpu[1:num_reqs + 1] = torch.Tensor(cu_num_tokens)
    self.query_lens = torch.from_numpy(num_scheduled_tokens)
    self.attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask()

    num_computed_tokens_cpu = (
        self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])

    for group_id, group_spec in enumerate(
            self.kv_cache_config.kv_cache_groups):
        block_table = self.input_batch.block_table[group_id]
        block_table_tensor = block_table.get_device_tensor()
        slot_mapping = block_table.slot_mapping
        self.cp_kv_recover_idx = torch.zeros(self.max_num_tokens,
                                             dtype=torch.int32,
                                             device=self.device)
        long_seq_metadata = self._generate_pcp_metadata(num_tokens)
        if long_seq_metadata is not None:
            pcp_world_size = get_pcp_group().world_size
            dcp_world_size = get_dcp_group().world_size
            # One zero-filled [pcp][dcp] table per token.
            long_seq_metadata.num_computed_tokens_of_pcp_dcp = [[
                [0] * dcp_world_size for _ in range(pcp_world_size)
            ] for _ in range(num_tokens)]
        # QUESTION: Why do we separately set query_start_loc for spec in the first place?
        # While in _prepare_inputs we don't?
        if self.speculative_config:
            self.query_start_loc.gpu[:num_reqs + 1] = torch.tensor(
                [0] + self.actual_seq_lengths_q[:num_reqs],
                device=self.device,
                dtype=torch.int32)

        query_start_loc = self.query_start_loc.gpu[:num_reqs + 1]
        query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1]

        common_attn_metadata = AscendCommonAttentionMetadata(
            query_start_loc=query_start_loc,
            query_start_loc_cpu=query_start_loc_cpu,
            seq_lens_cpu=self.seq_lens.cpu,
            seq_lens=self.seq_lens.gpu[:num_reqs],
            num_reqs=num_reqs,
            num_actual_tokens=num_tokens,
            num_input_tokens=num_tokens,
            actual_seq_lengths_q=self.actual_seq_lengths_q,
            block_table_tensor=block_table_tensor[:num_reqs],
            slot_mapping=slot_mapping.gpu,
            num_computed_tokens_cpu=num_computed_tokens_cpu,
            positions=self.positions.gpu,
            attn_mask=self.attn_mask,
            spec_attn_mask=self.spec_attn_mask,
            attn_state=self.attn_state,
            max_query_len=max_query_len,
            decode_token_per_req=self.decode_token_per_req,
            prefill_context_parallel_metadata=long_seq_metadata,
        )
        if self.pcp_size > 1:
            common_attn_metadata.block_table_tensor = \
                block_table_tensor[:num_reqs * self.decode_threshold]

        if self.speculative_config and \
                self.speculative_config.method == "mtp":
            attn_state = AscendAttentionState.SpecDecoding
        else:
            attn_state = AscendAttentionState.DecodeOnly

        # vLLM 0.12.0 takes the two CPU-side fields under their public
        # names; other versions take them under underscore-prefixed names.
        # All remaining arguments are identical.
        if vllm_version_is("0.12.0"):
            seq_lens_cpu_key = "seq_lens_cpu"
            num_computed_key = "num_computed_tokens_cpu"
        else:
            seq_lens_cpu_key = "_seq_lens_cpu"
            num_computed_key = "_num_computed_tokens_cpu"
        common_metadata = CommonAttentionMetadata(
            query_start_loc=query_start_loc,
            query_start_loc_cpu=query_start_loc_cpu,
            seq_lens=self.seq_lens.cpu[:num_reqs],
            num_reqs=num_reqs,
            num_actual_tokens=num_tokens,
            block_table_tensor=block_table_tensor[:num_reqs],
            slot_mapping=slot_mapping.gpu,
            max_query_len=max_query_len,
            max_seq_len=seq_lens,
            **{
                seq_lens_cpu_key: self.seq_lens.cpu[:num_reqs],
                num_computed_key: num_computed_tokens_cpu,
            },
        )

        for attn_group in self.attn_groups[group_id]:
            builder = attn_group.get_metadata_builder()
            if isinstance(builder, GDNAttentionMetadataBuilder):
                attn_metadata_gdn_attention = builder.build_for_cudagraph_capture(
                    common_metadata)
            else:
                attn_metadata_full_attention = builder.build_for_graph_capture(
                    common_attn_metadata, attn_state, self.get_model())
            # NOTE(review): each name below is only bound once a builder of
            # the matching kind has been seen -- confirm the group/layer
            # layout guarantees that.
            for layer_name in group_spec.layer_names:
                if "linear_attn" in layer_name:
                    attn_metadata[layer_name] = attn_metadata_gdn_attention
                else:
                    attn_metadata[layer_name] = attn_metadata_full_attention

    return attn_metadata
|
||
|
||
def _generate_dummy_run_hidden_states(self, input_ids, positions,
                                      num_tokens, intermediate_tensors,
                                      inputs_embeds):
    """Run the model once for a dummy batch and return its hidden states.

    When the forward context is in ``CUDAGraphMode.FULL``, is not
    capturing, and ``self.use_sparse`` is false, the attention parameters
    are additionally updated on ``self.update_stream`` with the helper
    matching the MLA / context-parallel configuration.

    Returns:
        The model output; for an EAGLE3 drafter only the first element of
        the returned pair.
    """
    model_output = self.model(input_ids=input_ids,
                              positions=positions,
                              intermediate_tensors=intermediate_tensors,
                              inputs_embeds=inputs_embeds)
    forward_context = get_forward_context()
    assert forward_context is not None

    needs_param_update = (
        forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL
        and not forward_context.capturing and not self.use_sparse)
    if needs_param_update:
        # FIXME: Try using `auto_dispatch_capture=True`
        use_mla = self.vllm_config.model_config.use_mla
        context_parallel = self.pcp_size * self.dcp_size > 1
        if use_mla and context_parallel:
            update_mla_attn_dcp_pcp_params(self.update_stream,
                                           forward_context,
                                           positions.shape[0])
        elif use_mla:
            update_mla_attn_params(self.update_stream, forward_context,
                                   num_tokens, self.speculative_config)
        elif context_parallel:
            update_attn_dcp_pcp_params(self.update_stream,
                                       forward_context,
                                       positions.shape[0])
        else:
            update_attn_params(self.update_stream, forward_context,
                               num_tokens, self.vllm_config)

    if self.drafter and self.drafter.name == SpecDcodeType.EAGLE3:
        # With EAGLE3 the model returns a pair; only the first element is
        # the hidden states wanted here.
        hidden_states, _ = model_output
        return hidden_states
    return model_output
|
||
|
||
@torch.inference_mode()
def _dummy_run(
    self,
    num_tokens: int,
    with_prefill: bool = False,
    aclgraph_runtime_mode: Optional[CUDAGraphMode] = None,
    force_attention: bool = False,
    uniform_decode: bool = False,
    is_profile: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Execute one forward pass on a synthetic batch.

    Args:
        num_tokens: Number of dummy tokens to run. It may be padded for
            sequence parallelism, DP synchronization and the graph
            dispatcher.
        with_prefill: Whether the dummy batch is treated as a prefill
            batch. Forced to ``True`` on a pure KV-producer node and may be
            changed by the DP synchronization.
        aclgraph_runtime_mode: Requested graph runtime mode. ``None`` takes
            the mode chosen by the dispatcher; any value other than
            ``NONE`` must agree with the dispatcher.
        force_attention: Build attention metadata even when the runtime
            mode is not ``FULL``.
        uniform_decode: Build a batch in which every request has
            ``self.uniform_decode_query_len`` tokens (the last one may be
            shorter).
        is_profile: Whether this is a profiling run; EPLB bookkeeping and
            dummy logits are handled differently in that case.

    Returns:
        The pair ``(hidden_states, hidden_states)``; both elements are the
        same object.

    Raises:
        ValueError: If ``aclgraph_runtime_mode`` is neither ``None`` nor
            ``NONE`` and differs from the dispatcher's mode.
    """
    # Eager (NONE), PIECEWISE and FULL are the accepted explicit modes.
    assert aclgraph_runtime_mode is None or aclgraph_runtime_mode in {
        CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
    }
    # In multi-DP scenarios, there may be situations where all DP groups are executing dummy runs.
    # If sequence parallelism is enabled, it is essential to ensure that num_tokens is divisible by tp_size.
    if self.use_aclgraph and enable_sp(self.vllm_config):
        tp_size = self.vllm_config.parallel_config.tensor_parallel_size
        num_tokens = math.ceil(num_tokens / tp_size) * tp_size

    # Force dummy run on prefill stage when this node is deemed as kv producer.
    if self.is_kv_producer and not self.is_kv_consumer:
        with_prefill = True

    # Padding for DP
    (num_tokens, num_tokens_across_dp,
     with_prefill) = self._sync_metadata_across_dp(num_tokens,
                                                   with_prefill)

    # A uniform decode batch means that all requests have identical query
    # length, except a potential (shorter) virtual request accounting for
    # padding. It is either common pure decode (max_query_len == 1) or
    # speculative decode (max_query_len == 1 + num_spec_decode_tokens).
    max_query_len = self.uniform_decode_query_len if uniform_decode else \
        num_tokens

    # Set num_scheduled_tokens based on num_tokens and max_num_seqs
    # for dummy run with LoRA so that the num_reqs collectively
    # has num_tokens in total.
    assert num_tokens <= self.scheduler_config.max_num_batched_tokens
    max_num_reqs = self.max_num_reqs
    if uniform_decode:
        num_reqs = cdiv(num_tokens, max_query_len)
        num_scheduled_tokens_list = [max_query_len] * num_reqs
        if num_tokens % max_query_len != 0:
            # The last request takes the remainder.
            num_scheduled_tokens_list[-1] = num_tokens % max_query_len
    else:
        if with_prefill:
            num_reqs = num_tokens
        else:
            num_reqs = (num_tokens + self.decode_token_per_req -
                        1) // self.decode_token_per_req
        num_reqs = min(num_reqs, max_num_reqs)
        min_tokens_per_req = num_tokens // num_reqs
        num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
        num_scheduled_tokens_list[-1] += num_tokens % num_reqs
    assert sum(num_scheduled_tokens_list) == num_tokens
    assert len(num_scheduled_tokens_list) == num_reqs
    num_scheduled_tokens = np.array(num_scheduled_tokens_list,
                                    dtype=np.int32)
    num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)

    if not is_profile and self.dynamic_eplb:
        self.eplb_updator.forward_before()

    has_lora = bool(self.lora_config
                    and self.compilation_config.cudagraph_specialize_lora)
    _ag_mode, batch_descriptor = self.cudagraph_dispatcher.dispatch(
        num_tokens=num_tokens,
        uniform_decode=uniform_decode,
        has_lora=has_lora)

    num_tokens_padded = batch_descriptor.num_tokens
    num_reqs_padded = (batch_descriptor.num_reqs if
                       batch_descriptor.num_reqs is not None else num_reqs)
    if num_tokens_across_dp is not None and num_tokens_padded != num_tokens:
        # pad is needed if the pad of `num_tokens` is triggered inside CudagraphDispatcher
        num_tokens_across_dp[:] = num_tokens_padded
        num_scheduled_tokens = num_scheduled_tokens.repeat(num_reqs_padded)

    # filter out the valid batch descriptor
    if aclgraph_runtime_mode is not None:
        # we allow forcing NONE when the dispatcher disagrees to support
        # warm ups for aclgraph capture
        if aclgraph_runtime_mode != CUDAGraphMode.NONE and aclgraph_runtime_mode != _ag_mode:
            raise ValueError(
                f"Aclgraph runtime mode mismatch at dummy_run. "
                f"Expected {_ag_mode}, but got {aclgraph_runtime_mode}.")
    else:
        aclgraph_runtime_mode = _ag_mode

    # TODO(Mengqing): Set create_mixed_batch to False since it's only used in FI warmup
    # and not supported in ASCEND now. We could remove it in the future.
    attn_metadata = self._build_dummy_attn_metadata(
        False,
        num_reqs=num_reqs_padded,
        num_tokens=num_tokens_padded,
        max_query_len=max_query_len,
        aclgraph_runtime_mode=aclgraph_runtime_mode,
        force_attention=force_attention,
        num_scheduled_tokens=num_scheduled_tokens,
    )

    with self.maybe_dummy_run_with_lora(self.lora_config,
                                        num_scheduled_tokens,
                                        num_sampled_tokens):
        # Make sure padding doesn't exceed max_num_tokens
        assert num_tokens_padded <= self.max_num_tokens
        if self.is_multimodal_model or self.enable_prompt_embeds:
            # Both cases feed embeddings instead of token ids.
            input_ids = None
            inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded]
        else:
            input_ids = self.input_ids.gpu[:num_tokens_padded]
            inputs_embeds = None

        if self.uses_mrope:
            positions = self.mrope_positions.gpu[:, :num_tokens_padded]
        else:
            positions = self.positions.gpu[:num_tokens_padded]

        # update global cos, sin
        update_cos_sin(positions)

        if get_pp_group().is_first_rank:
            intermediate_tensors = None
        else:
            # When PP and flashcomm1 are enabled, during dummy_run the estimated space should divide num_tokens by tp_size;
            # otherwise, on non-first PP ranks it would effectively perform an extra all-gather, leading to incorrect memory estimation and potentially causing OOM.
            actual_tokens = num_tokens
            if enable_sp():
                tp_size = get_tensor_model_parallel_world_size()
                actual_tokens = num_tokens // tp_size
            if self.intermediate_tensors is None:
                self.intermediate_tensors = (
                    self.model.make_empty_intermediate_tensors(
                        batch_size=actual_tokens,
                        dtype=self.dtype,
                        device=self.device))
            intermediate_tensors = IntermediateTensors({
                k: v[:num_tokens_padded]
                for k, v in self.intermediate_tensors.items()
            })

        need_dummy_logits = (not is_profile and lmhead_tp_enable())
        max_num_reqs_across_dp = max_num_reqs * self.uniform_decode_query_len
        dummy_indices = torch.zeros(max_num_reqs_across_dp,
                                    dtype=torch.int32)

        def dummy_compute_logits(hidden_states):
            # Only computes logits when lm-head TP is enabled outside of
            # profiling; otherwise a no-op.
            if not need_dummy_logits:
                return None
            return self.model.compute_logits(hidden_states[dummy_indices])

        def dummy_drafter_compute_logits(hidden_states):
            # Same as above, for a drafter exposing `model.compute_logits`.
            if not need_dummy_logits or self.drafter is None:
                return
            if hasattr(self.drafter, "model") and hasattr(
                    self.drafter.model, "compute_logits"):
                return self.drafter.model.compute_logits(
                    hidden_states[dummy_indices])

        with set_ascend_forward_context(
                attn_metadata,
                self.vllm_config,
                num_tokens=num_tokens_padded,
                num_tokens_across_dp=num_tokens_across_dp,
                with_prefill=with_prefill,
                in_profile_run=is_profile,
                num_actual_tokens=0,
                aclgraph_runtime_mode=aclgraph_runtime_mode,
                batch_descriptor=batch_descriptor,
                prefetch_stream=self.prefetch_stream,
                model_instance=self.model,
                weight_prefetch_method=self.weight_prefetch_method):
            hidden_states = self._generate_dummy_run_hidden_states(
                input_ids, positions, num_tokens_padded,
                intermediate_tensors, inputs_embeds)
            dummy_compute_logits(hidden_states)

        if self.drafter:
            self.drafter.dummy_run(
                num_tokens=num_tokens_padded,
                with_prefill=with_prefill,
                num_reqs=num_reqs_padded,
                num_tokens_across_dp=num_tokens_across_dp,
                aclgraph_runtime_mode=aclgraph_runtime_mode,
                batch_descriptor=batch_descriptor,
                dummy_compute_logits=dummy_drafter_compute_logits,
                in_graph_capturing=not force_attention,
                is_profile=is_profile)
        if is_profile and self.dynamic_eplb:
            self.model.clear_all_moe_loads()
        if not is_profile and self.dynamic_eplb:
            self.eplb_updator.take_update_info_from_eplb_process()
            self.eplb_updator.forward_end()
        return hidden_states, hidden_states
|
||
|
||
@torch.inference_mode()
def _dummy_sampler_run(
    self,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    """Compute logits for the last token of each request of a maximal batch.

    The batch has ``self.max_num_reqs`` requests that together hold
    ``self.max_num_tokens`` tokens; the last request takes the remainder.

    Args:
        hidden_states: Hidden states of the dummy run, or a one-element
            list wrapping them.

    Returns:
        The result of ``self.model.compute_logits`` on the selected rows.
    """
    # For profile, have maximum num_reqs and that collectively have
    # maximum num_tokens.
    tokens_per_req, leftover = divmod(self.max_num_tokens,
                                      self.max_num_reqs)
    num_scheduled_tokens = np.full(self.max_num_reqs,
                                   tokens_per_req,
                                   dtype=np.int32)
    num_scheduled_tokens[-1] += leftover
    # Index of the last token of every request.
    logit_indices = np.cumsum(num_scheduled_tokens) - 1

    # TODO: need to run a dummy sampler for generate task
    # Sometimes, after the model is compiled through the AOT backend,
    # the model output may become a list containing only one Tensor object.
    wrapped_in_list = (isinstance(hidden_states, list)
                       and len(hidden_states) == 1
                       and isinstance(hidden_states[0], torch.Tensor))
    if wrapped_in_list:
        hidden_states = hidden_states[0]

    return self.model.compute_logits(hidden_states[logit_indices])
|
||
|
||
def profile_run(self) -> None:
    """Run the profiling pass, with an extra MC2-sized dummy run if needed.

    When ``self.max_num_tokens`` exceeds the MC2 token capacity and the MoE
    communication method selected for that capacity is MC2 or FUSED_MC2, a
    prefill dummy run with exactly that many tokens is executed before the
    parent class's ``profile_run``.
    """
    mc2_tokens_capacity = get_mc2_tokens_capacity()
    if self.max_num_tokens > mc2_tokens_capacity:
        comm_method = select_moe_comm_method(mc2_tokens_capacity,
                                             self.vllm_config)
        if comm_method in {MoECommType.MC2, MoECommType.FUSED_MC2}:
            self._dummy_run(mc2_tokens_capacity,
                            with_prefill=True,
                            is_profile=True)
    super().profile_run()
|
||
|
||
def eplb_warmup(self):
    """Warm up dynamic EPLB once; later calls are no-ops.

    Does nothing unless ``self.dynamic_eplb`` is set and the warm-up has
    not happened yet. Otherwise it marks the warm-up as done, creates the
    adaptor for ``self.model``, hands it to the loader and the updator,
    and triggers the updator's warm-up.
    """
    if not self.dynamic_eplb or self.is_eplb_warmuped:
        return

    self.is_eplb_warmuped = True
    adaptor = VllmEplbAdaptor(model=self.model)
    self.eplb_adaptor = adaptor
    # `set_adator` (sic) is the loader's method name as used by this file.
    self.eplb_loader.set_adator(adaptor)
    self.eplb_updator.set_adaptor(adaptor)
    self.eplb_updator.warm_up_eplb()
|
||
|
||
def load_model(self) -> None:
    """Load the main model and its optional companions.

    Inside a device memory profiler this loads the model, registers it for
    dynamic EPLB if enabled, loads the drafter (wiring the EAGLE3 auxiliary
    hidden-state layers when applicable) and applies LoRA. The consumed
    memory is logged, and the model is finally wrapped in an
    ``ACLGraphWrapper`` when full graphs are configured.
    """
    logger.info("Starting to load model %s...", self.model_config.model)

    with DeviceMemoryProfiler() as memory_profiler:  # noqa: SIM117
        self.model = get_model(vllm_config=self.vllm_config)
        if self.dynamic_eplb:
            model_register(self.model, self.model_config)
        if self.drafter:
            logger.info("Loading drafter model...")
            self.drafter.load_model(self.model)
            if self.drafter.name == SpecDcodeType.EAGLE3:
                aux_layers = self.model.get_eagle3_aux_hidden_state_layers()
                self.model.set_aux_hidden_state_layers(aux_layers)

        if self.lora_config:
            self.model = self.load_lora_model(self.model, self.vllm_config,
                                              self.device)

    bytes_per_gib = float(2**30)
    logger.info("Loading model weights took %.4f GB",
                memory_profiler.consumed_memory / bytes_per_gib)

    # wrap the model with full graph wrapper if needed.
    if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
        self.update_stream: torch.npu.Stream = torch.npu.Stream()
        self.model = ACLGraphWrapper(self.model,
                                     self.vllm_config,
                                     runtime_mode=CUDAGraphMode.FULL)
|
||
|
||
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
    """
    Initialize KV cache based on `kv_cache_config`.

    A deep copy of the config is stored on the runner, so the caller's
    object is not the one kept in `self.kv_cache_config`.

    Args:
        kv_cache_config: Configuration for the KV cache, including the KV
            cache size of each layer
    """
    config_copy = deepcopy(kv_cache_config)
    self.kv_cache_config = config_copy
    self.may_add_encoder_only_layers_to_kv_cache_config()
    # NOTE(cmq): initialize_attn_backend must before using self.attn_groups
    self.initialize_attn_backend(config_copy)
    self.use_hybrid_blocks = (len(self.attn_groups) > 1)
    # NOTE: Currently, we determine whether we need `num_accepted_tokens` through `MambaSpec`.
    group_is_mamba = [
        isinstance(attn_group[0].kv_cache_spec, MambaSpec)
        for attn_group in self.attn_groups
    ]
    self.need_accepted_tokens = any(group_is_mamba)

    self.may_reinitialize_input_batch(config_copy)
    kv_caches = self.initialize_kv_cache_tensors(config_copy)

    if has_kv_transfer_group():
        get_kv_transfer_group().register_kv_caches(kv_caches)
|
||
|
||
def _align_memory(self, tensor: torch.Tensor,
                  alignment: int) -> torch.Tensor:
    """Return a suffix of ``tensor`` starting at an aligned address.

    Args:
        tensor: Tensor to slice along its first dimension.
        alignment: Required address alignment in bytes.

    Returns:
        ``tensor[offset:]``, where ``offset`` is the number of leading
        elements to skip so that the data pointer reaches the next multiple
        of ``alignment`` (zero when it is already aligned).
    """
    start_addr = tensor.data_ptr()
    # Bytes missing up to the next multiple of `alignment`.
    padding_bytes = -start_addr % alignment
    leading_elements = padding_bytes // tensor.element_size()
    return tensor[leading_elements:]
|
||
|
||
def initialize_kv_cache_tensors(
        self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
    """
    Initialize the memory buffer for KV cache.

    The raw buffers are allocated, reshaped, and then bound to the static
    forward context and to `self.kv_caches`.

    Args:
        kv_cache_config: The KV cache config
    Returns:
        Dict[str, torch.Tensor]: A map between layer names to their
        corresponding memory buffer for KV cache.
    """
    from vllm.v1.worker.utils import bind_kv_cache

    # Allocate the raw buffers, then give them their final shape.
    raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config)
    kv_caches = self._reshape_kv_cache_tensors(kv_cache_config, raw_tensors)

    forward_context = self.compilation_config.static_forward_context
    bind_kv_cache(kv_caches, forward_context, self.kv_caches)
    return kv_caches
|
||
|
||
def _allocate_kv_cache_tensors(
        self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
    """
    Initializes the KV cache buffer with the correct size. The buffer needs
    to be reshaped to the desired shape before being used by the models.

    NOTE: To support prefill disaggregation, the KV cache of an attention
    layer is split into a K cache and a V cache that are allocated
    separately, and the address of each is aligned to 2M whenever
    `kv_transfer_config` is set.

    Args:
        kv_cache_config: The KV cache config
    Returns:
        A map between layer names and their raw `int8` buffers: a single
        tensor for "linear_attn" layers, a `(k, v)` tuple for the other
        attention layers, or `(k, v, dsa_k)` when `self.use_sparse` is set.
    Raises:
        AssertionError: if the layers listed in
            `kv_cache_config.kv_cache_groups` (minus
            `self.runner_only_attn_layers`) differ from the layers that
            received a buffer.
    """
    # init kv cache tensors
    kv_cache_raw_tensors: dict[str, Union[torch.Tensor,
                                          Optional[torch.Tensor]]] = {}
    # prefill disaggregation need the addr of cache tensor be aligned with 2M
    alignment = 2 * 1024 * 1024
    need_alignment = self.vllm_config.kv_transfer_config is not None

    def alloc_bytes(num_bytes: int) -> torch.Tensor:
        # Zero-filled int8 buffer of `num_bytes`; when alignment is needed,
        # over-allocate by `alignment` and cut out an aligned window.
        if not need_alignment:
            return torch.zeros(num_bytes,
                               dtype=torch.int8,
                               device=self.device)
        padded = torch.zeros(num_bytes + alignment,
                             dtype=torch.int8,
                             device=self.device)
        return self._align_memory(padded, alignment)[:num_bytes]

    for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
        # TODO: REFACTOR ME to sharing hybrid cache
        total_bytes = kv_cache_tensor.size
        for layer_name in kv_cache_tensor.shared_by:
            if layer_name in kv_cache_raw_tensors:
                # Already covered by a buffer allocated for a sibling layer.
                continue

            if "linear_attn" in layer_name:
                # for mamba linear attention
                tensor = alloc_bytes(total_bytes)
                for layer_name_inner in kv_cache_tensor.shared_by:
                    # shared the kvcache between the linear_attn specs in the same group
                    if "linear_attn" in layer_name_inner:
                        kv_cache_raw_tensors[layer_name_inner] = tensor
            elif "attn" in layer_name:
                # NOTE: We need to init k cache tensor (nope cache tensor in mla) and
                # v cache tensor (rope cache tensor in mla) separately to support prefill disaggregation,
                # as it only support the 0-dim of kv_cache is `num_blocks`.
                # For deepseek mla, we need to split cache tensor according to the nope head dim
                # and rope head dim.
                dsa_k_cache_size = None
                # Reset per tensor so a buffer from a previous iteration can
                # never leak into this layer's tuple.
                dsa_k_cache_tensor = None
                if not self.model_config.use_mla:
                    # for non-mla model, use FullAttentionSpec:
                    # K and V each take half of the bytes.
                    k_tensor_size = total_bytes // 2
                    v_tensor_size = total_bytes // 2
                else:
                    hf_text_config = self.model_config.hf_text_config
                    kv_lora_rank = hf_text_config.kv_lora_rank
                    qk_rope_head_dim = hf_text_config.qk_rope_head_dim
                    head_size = qk_rope_head_dim + kv_lora_rank
                    if self.use_sparse:
                        # for deepseek v3.2, DSA use FullAttentionSpec
                        # FullAttentionSpec allocate 2 * mla page size bytes,
                        # and we use half of that for k cache in DSA
                        split_denominator = 2 * head_size
                        dsa_k_cache_size = total_bytes // 2
                    else:
                        # for other deepseek models, use MLAAttentionSpec
                        split_denominator = head_size
                    # Integer arithmetic keeps the split exact; dividing by a
                    # float ratio may floor to one byte less than intended.
                    k_tensor_size = (total_bytes * kv_lora_rank //
                                     split_denominator)
                    v_tensor_size = (total_bytes * qk_rope_head_dim //
                                     split_denominator)

                # for other attentions, e.g., self_attn, sliding window attn
                k_tensor = alloc_bytes(k_tensor_size)
                v_tensor = alloc_bytes(v_tensor_size)
                #### k cache: for deepseek sparse attention
                if dsa_k_cache_size is not None:
                    dsa_k_cache_tensor = alloc_bytes(dsa_k_cache_size)

                if self.use_sparse:
                    layer_buffers = (k_tensor, v_tensor, dsa_k_cache_tensor)
                else:
                    layer_buffers = (k_tensor, v_tensor)
                for layer_name_inner in kv_cache_tensor.shared_by:
                    # shared the kvcache between the self_attn specs in the same group
                    if ("attn" in layer_name_inner
                            and "linear_attn" not in layer_name_inner):
                        kv_cache_raw_tensors[layer_name_inner] = layer_buffers

    expected_layer_names = {
        layer_name
        for group in kv_cache_config.kv_cache_groups
        for layer_name in group.layer_names
        if layer_name not in self.runner_only_attn_layers
    }
    assert expected_layer_names == set(
        kv_cache_raw_tensors), "Some layers are not correctly initialized"

    return kv_cache_raw_tensors
|
||
|
||
def _reshape_kv_cache_tensors(
    self,
    kv_cache_config: KVCacheConfig,
    kv_cache_raw_tensors: dict[str, torch.Tensor],
) -> dict[str, torch.Tensor]:
    """
    Reshape the KV cache tensors to the desired shape and dtype.

    For `FullAttentionSpec` layers the raw entry is unpacked as `(k, v)`
    (or `(k, v, dsa_k)` when `self.use_sparse` is set) and each buffer is
    viewed with the spec dtype and the backend-provided shape. For
    `MambaSpec` layers one strided view per state shape is created over
    the single raw buffer.

    Args:
        kv_cache_config: The KV cache config
        kv_cache_raw_tensors: The KV cache buffer of each layer, with
        correct size but uninitialized shape.
    Returns:
        Dict[str, torch.Tensor]: A map between layer names to their
        corresponding memory buffer for KV cache.
    Raises:
        ValueError: if a group's spec is neither `FullAttentionSpec` nor
            `MambaSpec`.
    """
    kv_caches: Dict[str, torch.Tensor] = {}
    for group in self._kv_cache_spec_attn_group_iterator():
        kv_cache_spec = group.kv_cache_spec
        attn_backend = group.backend
        for layer_name in group.layer_names:
            if layer_name in self.runner_only_attn_layers:
                continue

            # TODO: remove this after the OOM issue is located and fixed, otherwise, some model may
            # encounter OOM issue
            if isinstance(kv_cache_spec, FullAttentionSpec):
                raw_dsa_k_tensor = None
                if self.use_sparse:
                    raw_k_tensor, raw_v_tensor, raw_dsa_k_tensor = kv_cache_raw_tensors[  # type: ignore
                        layer_name]
                    assert raw_dsa_k_tensor is not None
                    sum_page_size_bytes = raw_k_tensor.numel(
                    ) + raw_v_tensor.numel() + raw_dsa_k_tensor.numel()
                else:
                    raw_k_tensor, raw_v_tensor = kv_cache_raw_tensors[  # type: ignore
                        layer_name]
                    sum_page_size_bytes = raw_k_tensor.numel(
                    ) + raw_v_tensor.numel()
                assert raw_k_tensor is not None
                assert raw_v_tensor is not None
                # The element counts are used directly as byte counts here
                # (presumably the raw buffers are int8 -- confirm against
                # the allocator).
                assert sum_page_size_bytes % kv_cache_spec.page_size_bytes == 0
                num_blocks = sum_page_size_bytes // kv_cache_spec.page_size_bytes

                # `num_blocks` is the number of blocks the model runner can use.
                # `kv_cache_config.num_blocks` is the number of blocks that
                # KVCacheManager may allocate.
                # Since different GPUs may have different number of layers and
                # different memory capacities, `num_blocks` can be different on
                # different GPUs, and `kv_cache_config.num_blocks` is set to
                # the min of all `num_blocks`. Verify it here.
                assert num_blocks >= kv_cache_config.num_blocks

                if hasattr(attn_backend, "get_supported_block_size"
                           ) and self.use_hybrid_blocks:
                    # Hybrid case: use the first block size the backend
                    # reports and scale the block count so the total number
                    # of slots stays the same.
                    block_size = attn_backend.get_supported_block_size()[0]

                    block_size_chunk = kv_cache_spec.block_size // block_size
                    kv_cache_shape = attn_backend.get_kv_cache_shape(
                        num_blocks * block_size_chunk, block_size,
                        kv_cache_spec.num_kv_heads,
                        kv_cache_spec.head_size)
                else:
                    # NOTE(review): this branch asks `self.attn_backend`
                    # rather than the group's `attn_backend` -- confirm this
                    # is intentional.
                    kv_cache_shape = self.attn_backend.get_kv_cache_shape(
                        num_blocks, kv_cache_spec.block_size,
                        kv_cache_spec.num_kv_heads,
                        kv_cache_spec.head_size)
                dtype = kv_cache_spec.dtype
                if not self.model_config.use_mla:
                    # Drop the leading dimension of the backend shape; K and
                    # V share the remaining shape.
                    k_shape = kv_cache_shape[1:]
                    v_shape = k_shape
                else:
                    # k_cache: nope_cache v_cache: rope_cache
                    mla_num_blocks, mla_block_size, num_kv_heads, _ = kv_cache_shape
                    k_shape = [
                        mla_num_blocks, mla_block_size, num_kv_heads,
                        self.model_config.hf_text_config.kv_lora_rank
                    ]
                    v_shape = [
                        mla_num_blocks, mla_block_size, num_kv_heads,
                        self.model_config.hf_text_config.qk_rope_head_dim
                    ]
                # Reinterpret the raw buffers with the spec dtype, then
                # apply the target shape.
                k_cache = raw_k_tensor.view(dtype).view(k_shape)
                v_cache = raw_v_tensor.view(dtype).view(v_shape)
                if get_ascend_device_type() == AscendDeviceType._310P:
                    k_cache = maybe_trans_nz(k_cache)
                    v_cache = maybe_trans_nz(v_cache)
                if self.use_sparse and raw_dsa_k_tensor is not None:
                    # The DSA k cache uses a fixed last dimension of 128 and
                    # only the leading `dsa_k_cache_size` elements of its
                    # raw buffer.
                    dsa_k_cache_shape = (num_blocks,
                                         kv_cache_spec.block_size, 1, 128)
                    dsa_k_cache_size = (
                        num_blocks
                    ) * kv_cache_spec.block_size * 128 * dtype.itemsize
                    dsa_k_cache = raw_dsa_k_tensor[:dsa_k_cache_size].view(
                        dtype).view(dsa_k_cache_shape)
                    kv_caches[layer_name] = (k_cache, v_cache, dsa_k_cache)
                else:
                    kv_caches[layer_name] = (k_cache, v_cache)
            elif isinstance(kv_cache_spec, MambaSpec):
                raw_tensor = kv_cache_raw_tensors[layer_name]
                assert raw_tensor is not None
                assert raw_tensor.numel(
                ) % kv_cache_spec.page_size_bytes == 0
                num_blocks = raw_tensor.numel(
                ) // kv_cache_spec.page_size_bytes

                # `num_blocks` is the number of blocks the model runner can use.
                # `kv_cache_config.num_blocks` is the number of blocks that
                # KVCacheManager may allocate.
                # Since different GPUs may have different number of layers and
                # different memory capacities, `num_blocks` can be different on
                # different GPUs, and `kv_cache_config.num_blocks` is set to
                # the min of all `num_blocks`. Verify it here.
                assert num_blocks >= kv_cache_config.num_blocks

                state_tensors = []
                storage_offset_bytes = 0
                for (shape, dtype) in zip(kv_cache_spec.shapes,
                                          kv_cache_spec.dtypes):
                    dtype_size = get_dtype_size(dtype)
                    num_element_per_page = (
                        kv_cache_spec.page_size_bytes // dtype_size)
                    target_shape = (num_blocks, *shape)
                    # The temporary tensor is only used to read the default
                    # strides of `target_shape`.
                    stride = torch.empty(target_shape).stride()
                    # The block dimension steps by one full page; the inner
                    # strides are taken from the temporary tensor.
                    target_stride = (num_element_per_page, *stride[1:])
                    assert storage_offset_bytes % dtype_size == 0
                    tensor = torch.as_strided(
                        raw_tensor.view(dtype),
                        size=target_shape,
                        stride=target_stride,
                        storage_offset=storage_offset_bytes // dtype_size,
                    )
                    state_tensors.append(tensor)
                    # Advance past this state's per-block footprint so the
                    # next state starts right after it inside each page.
                    storage_offset_bytes += stride[0] * dtype_size
                kv_caches[layer_name] = state_tensors
            else:
                raise ValueError("Unknown KV cache spec type.")

    return kv_caches
|
||
|
||
def may_reinitialize_input_batch(self,
                                 kv_cache_config: KVCacheConfig) -> None:
    """
    Rebuild `self.input_batch` when the KV cache groups do not all use
    `[self.cache_config.block_size]` as block size / kernel block size.
    This usually happens when there are multiple KV cache groups.

    Args:
        kv_cache_config: The KV cache configuration.
    """
    default_block_size = self.cache_config.block_size

    # One entry per non encoder-only group, in group order.
    block_sizes = []
    kernel_block_sizes = []
    for group_id, kv_cache_group in enumerate(
            kv_cache_config.kv_cache_groups):
        group_spec = kv_cache_group.kv_cache_spec
        if isinstance(group_spec, EncoderOnlyAttentionSpec):
            continue
        block_sizes.append(group_spec.block_size)

        if not isinstance(group_spec, AttentionSpec):
            # This is likely Mamba or other non-attention cache,
            # no splitting.
            # NOTE: set kernel_block_sizes to 0 to disable slotmapping computation
            # of mamba block. In this case, BlockTable.block_size will never equal
            # to kernel_block_sizes[0]
            kernel_block_sizes.append([0])
            continue

        # Attention spec: the backend may support virtual block splitting,
        # in which case its supported block sizes are used.
        try:
            groups_for_spec = self.attn_groups[group_id]
        except IndexError:
            groups_for_spec = None

        # Fallback to cache config block_size if no backend is found or
        # the backend reports no specific sizes.
        kernel_block_size_list = [default_block_size]
        if groups_for_spec and self.use_hybrid_blocks:
            supported_sizes = groups_for_spec[
                0].backend.get_supported_block_size()
            if supported_sizes:
                kernel_block_size_list = supported_sizes
        kernel_block_sizes.append(kernel_block_size_list)

    if (block_sizes == [default_block_size]
            and kernel_block_sizes == [[default_block_size]]):
        # The existing input batch already matches; nothing to do.
        return

    assert self.cache_config.cpu_offload_gb == 0, (
        "Cannot re-initialize the input batch when CPU weight "
        "offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 "  # noqa: E501
        "for more details.")
    speculative_config = self.vllm_config.speculative_config
    if speculative_config:
        num_speculative_tokens = speculative_config.num_speculative_tokens
    else:
        num_speculative_tokens = 0
    self.input_batch = NPUInputBatch(
        max_num_reqs=self.max_num_reqs,
        max_model_len=self.model_config.max_model_len,
        max_num_batched_tokens=self.max_num_tokens,
        device=self.device,
        pin_memory=self.pin_memory,
        vocab_size=self.model_config.get_vocab_size(),
        block_sizes=block_sizes,
        is_spec_decode=bool(speculative_config),
        logitsprocs=self.input_batch.logitsprocs,
        is_pooling_model=self.is_pooling_model,
        num_speculative_tokens=num_speculative_tokens,
        kernel_block_sizes=kernel_block_sizes,
    )
|
||
|
||
def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
    """
    Initialize the attention backends and attention metadata builders.

    One list of `AttentionGroup`s is appended to `self.attn_groups` per
    entry of `kv_cache_config.kv_cache_groups`; afterwards the reorder
    batch threshold is recalculated.
    """
    assert len(self.attn_groups) == 0, \
        "Attention backends are already initialized"

    class AttentionGroupKey(NamedTuple):
        attn_backend: type[AttentionBackend]
        kv_cache_spec: KVCacheSpec

    def get_attn_backends_for_group(
        kv_cache_group_spec: KVCacheGroupSpec,
    ) -> dict[AttentionGroupKey, list[str]]:
        layer_modules = get_layers_from_vllm_config(
            self.vllm_config, AttentionLayerBase,
            kv_cache_group_spec.layer_names)
        group_key_by_id = {}
        layer_names_by_id: dict = {}
        # Dedupe based on full class name; this is a bit safer than
        # using the class itself as the key because when we create dynamic
        # attention backend subclasses (e.g. ChunkedLocalAttention) unless
        # they are cached correctly, there will be different objects per
        # layer.
        for name in kv_cache_group_spec.layer_names:
            backend_cls = layer_modules[name].get_attn_backend()
            spec_for_layer = kv_cache_group_spec.kv_cache_spec
            if isinstance(spec_for_layer, UniformTypeKVCacheSpecs):
                spec_for_layer = spec_for_layer.kv_cache_specs[name]
            dedupe_id = (backend_cls.full_cls_name(), spec_for_layer)
            # The most recently seen backend object wins for a given id.
            group_key_by_id[dedupe_id] = AttentionGroupKey(
                backend_cls, spec_for_layer)
            layer_names_by_id.setdefault(dedupe_id, []).append(name)
        return {
            group_key_by_id[dedupe_id]: names
            for dedupe_id, names in layer_names_by_id.items()
        }

    def create_attn_groups(attn_backends_map: dict[AttentionGroupKey,
                                                   list[str]],
                           kv_cache_group_id: int) -> list[AttentionGroup]:
        created: list[AttentionGroup] = []
        for group_key, names in attn_backends_map.items():
            backend_cls, spec = group_key
            # Exactly one metadata builder is created per group.
            builder = backend_cls.get_builder_cls()(
                spec,
                names,
                self.vllm_config,
                self.device,
            )
            created.append(
                AttentionGroup(backend_cls, names, spec, kv_cache_group_id,
                               [builder]))
        return created

    for group_id, group_spec in enumerate(kv_cache_config.kv_cache_groups):
        backends_for_group = get_attn_backends_for_group(  # type: ignore
            group_spec)
        self.attn_groups.append(
            create_attn_groups(backends_for_group, group_id))

    # Calculate reorder batch threshold (if needed)
    self.calculate_reorder_batch_threshold()
|
||
|
||
def calculate_reorder_batch_threshold(self) -> None:
    """
    Adopt the reorder batch threshold of the attention metadata builders.

    Builders without a `reorder_batch_threshold` attribute, or with it set
    to `None`, are ignored. The first non-`None` value is stored in
    `self.reorder_batch_threshold` when that is still `None`.

    Raises:
        ValueError: if a builder's threshold differs from the one already
            stored on the runner.
    """
    for group in self._attn_group_iterator():
        builder = group.get_metadata_builder()
        builder_threshold = getattr(builder, "reorder_batch_threshold",
                                    None)
        if builder_threshold is None:
            continue

        if self.reorder_batch_threshold is None:
            self.reorder_batch_threshold = builder_threshold  # noqa
            continue

        # check that if any backends reorder batches; that the reordering
        # is compatible (e.g., decode threshold is the same)
        if builder_threshold != self.reorder_batch_threshold:
            raise ValueError(
                f"Attention backend reorders decodes with "
                f"threshold {builder_threshold} but other "
                f"backend uses threshold "
                f"{self.reorder_batch_threshold}")
|
||
|
||
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
    """
    Build the KVCacheSpec of every layer that needs a KV cache, based on
    the attention and mamba modules found in the vLLM config.

    Returns:
        KVCacheSpec: A dictionary mapping layer names to their KV cache
        format. Layers that do not need KV cache are not included. The
        dictionary is empty when this worker is an EC transfer producer.
    Raises:
        NotImplementedError: for encoder-decoder attention, and for mamba
            layers combined with unsupported speculative decoding or with
            prefix caching.
        ValueError: for an unknown attention type.
    """
    if has_ec_transfer() and get_ec_transfer().is_producer:
        return {}

    vllm_config = self.vllm_config
    block_size = vllm_config.cache_config.block_size
    use_mla = vllm_config.model_config.use_mla
    kv_cache_spec: dict[str, KVCacheSpec] = {}

    attn_layers = get_layers_from_vllm_config(vllm_config,
                                              AttentionLayerBase)
    for layer_name, attn_module in attn_layers.items():
        if isinstance(attn_module, Attention):
            kv_tgt_layer = attn_module.kv_sharing_target_layer_name
            if kv_tgt_layer is not None:
                # The layer doesn't need its own KV cache and will use that of
                # the target layer. We skip creating a KVCacheSpec for it, so
                # that KV cache management logic will act as this layer does
                # not exist, and doesn't allocate KV cache for the layer. This
                # enables the memory saving of cross-layer kv sharing, allowing
                # a given amount of memory to accommodate longer context lengths
                # or enable more requests to be processed simultaneously.
                self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
                continue

            # TODO: Support other attention modules, e.g., cross-attention
            # TODO(lucas): move the attention specs into the model layers like
            # the attention backends
            attn_type = attn_module.attn_type
            if attn_type == AttentionType.DECODER:
                kv_cache_spec[layer_name] = FullAttentionSpec(
                    block_size=block_size,
                    num_kv_heads=attn_module.num_kv_heads,
                    head_size=attn_module.head_size,
                    dtype=self.kv_cache_dtype)
            elif attn_type in (AttentionType.ENCODER,
                               AttentionType.ENCODER_ONLY):
                # encoder-only attention does not need KV cache.
                continue
            elif attn_type == AttentionType.ENCODER_DECODER:
                raise NotImplementedError
            else:
                raise ValueError(
                    f"Unknown attention type: {attn_module.attn_type}")

        elif isinstance(attn_module, MLAAttention):
            shared_spec_kwargs = dict(block_size=block_size,
                                      num_kv_heads=1,
                                      head_size=attn_module.head_size,
                                      dtype=self.kv_cache_dtype)
            if use_mla and not self.use_sparse:
                kv_cache_spec[layer_name] = MLAAttentionSpec(
                    **shared_spec_kwargs,
                    cache_dtype_str=self.cache_config.cache_dtype)
            else:
                # TODO(cmq): This is a hack way to fix deepseek kvcache when
                # using DSA. Fixing the spec in vLLM is the final way.
                kv_cache_spec[layer_name] = FullAttentionSpec(
                    **shared_spec_kwargs)

    mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaBase)
    if not mamba_layers:
        return kv_cache_spec

    if (self.vllm_config.speculative_config is not None
            and self.vllm_config.model_config.hf_config.model_type
            not in ["qwen3_next"]):
        raise NotImplementedError(
            "Mamba with speculative decoding is not supported yet.")
    if self.vllm_config.cache_config.enable_prefix_caching:
        raise NotImplementedError(
            "Prefix caching is not supported for Mamba yet.")

    max_model_len = self.vllm_config.model_config.max_model_len
    page_size_padded = (
        self.vllm_config.cache_config.mamba_page_size_padded)
    if self.speculative_config:
        num_speculative_blocks = (
            self.speculative_config.num_speculative_tokens)
    else:
        num_speculative_blocks = 0

    # Set block_size to max_model_len, so that mamba model will always
    # have only one block in the KV cache.
    for layer_name, mamba_module in mamba_layers.items():
        kv_cache_spec[layer_name] = MambaSpec(
            shapes=mamba_module.get_state_shape(),
            dtypes=mamba_module.get_state_dtype(),
            block_size=max_model_len,
            page_size_padded=page_size_padded,
            mamba_type=mamba_module.mamba_type,
            num_speculative_blocks=num_speculative_blocks,
        )

    return kv_cache_spec
|
||
|
||
def initialize_aclgraph_capture(self) -> None:
    """
    Resolve the ACL graph mode against what the attention backends support
    and initialize the graph parameters and dispatcher keys.

    Raises:
        ValueError: if full graphs are requested but the least capable
            attention backend never supports them.
    """
    # Track the least capable backend across all attention groups.
    min_ag_support = AttentionCGSupport.ALWAYS
    min_ag_builder_name = None
    for attn_group in self._attn_group_iterator():
        builder = attn_group.get_metadata_builder()
        if hasattr(builder, 'aclgraph_support'):
            builder_support = builder.aclgraph_support
        else:
            builder_support = builder._cudagraph_support
        if builder_support.value < min_ag_support.value:
            min_ag_support = builder_support
            min_ag_builder_name = builder.__class__.__name__

    # This is an imitation of compilation_config.splitting_ops_contain_attention()
    splitting_ops = self.compilation_config.splitting_ops
    splitting_ops_contain_attention = (splitting_ops is not None and
                                       "vllm.mla_forward" in splitting_ops)

    # Flexible resolve the aclgraph mode
    aclgraph_mode = self.compilation_config.cudagraph_mode
    # check graph for mixed batch is supported
    mixed_full_unsupported = (
        aclgraph_mode.mixed_mode() == CUDAGraphMode.FULL
        and min_ag_support != AttentionCGSupport.ALWAYS)
    if mixed_full_unsupported:
        msg = (f"ACLGraphMode.{aclgraph_mode.name} is not supported "
               f"with {min_ag_builder_name} backend (support: "
               f"{min_ag_support})")
        if min_ag_support == AttentionCGSupport.NEVER:
            # if not supported any full graphs, just raise it.
            msg += "; please try cudagraph_mode=PIECEWISE, and "\
                "make sure compilation level is piecewise"
            raise ValueError(msg)

        # attempt to resolve the full graph related mode
        if splitting_ops_contain_attention:
            msg += "; setting cudagraph_mode=FULL_AND_PIECEWISE"
            downgraded_mode = CUDAGraphMode.FULL_AND_PIECEWISE
        else:
            msg += "; setting cudagraph_mode=FULL_DECODE_ONLY"
            downgraded_mode = CUDAGraphMode.FULL_DECODE_ONLY
        self.compilation_config.cudagraph_mode = downgraded_mode
        aclgraph_mode = downgraded_mode
        logger.warning(msg)

    # double check that we can support full graph if they are requested
    # even after automatic downgrades
    if aclgraph_mode.has_full_cudagraphs() \
        and min_ag_support == AttentionCGSupport.NEVER:
        raise ValueError(f"CUDAGraphMode.{aclgraph_mode.name} is not "
                         f"supported with {min_ag_builder_name} backend ("
                         f"support:{min_ag_support}) "
                         "; please try cudagraph_mode=PIECEWISE, "
                         "and make sure compilation level is piecewise")

    # With a separate full-graph decode routine and more than one query
    # token per decode request, the capture sizes are adjusted and then
    # re-read from the compilation config.
    if (aclgraph_mode.decode_mode() == CUDAGraphMode.FULL
            and aclgraph_mode.separate_routine()
            and self.uniform_decode_query_len > 1):
        self.compilation_config.adjust_cudagraph_sizes_for_spec_decode(
            self.uniform_decode_query_len,
            self.parallel_config.tensor_parallel_size)
        capture_sizes = self.compilation_config.cudagraph_capture_sizes
        if capture_sizes is None:
            capture_sizes = []
        self.cudagraph_batch_sizes = capture_sizes

    # NOTE: Since aclgraph_batch_sizes cannot be determined until here,
    # we set the graph params right before initializing the keys.
    set_graph_params(self.cudagraph_batch_sizes)
    if self.speculative_config:
        set_mtp_graph_params(self.cudagraph_batch_sizes)

    self.cudagraph_dispatcher.initialize_cudagraph_keys(
        self.compilation_config.cudagraph_mode,
        self.uniform_decode_query_len)
|
||
|
||
def _capture_aclgraphs(self, compilation_cases: list[int],
                       aclgraph_runtime_mode: CUDAGraphMode,
                       uniform_decode: bool):
    """
    Warm up and capture one ACL graph per entry of `compilation_cases`.

    Args:
        compilation_cases: Token counts to capture, in capture order.
        aclgraph_runtime_mode: `CUDAGraphMode.FULL` or
            `CUDAGraphMode.PIECEWISE`.
        uniform_decode: Forwarded to every `_dummy_run` call.
    """
    assert aclgraph_runtime_mode != CUDAGraphMode.NONE and \
        aclgraph_runtime_mode in [CUDAGraphMode.FULL,
                                  CUDAGraphMode.PIECEWISE]

    # Only rank 0 should print progress bar during capture
    if is_global_first_rank():
        logger.info(
            "Starting to capture ACL graphs for cases: %s, "
            "mode: %s, uniform_decode: %s", compilation_cases,
            aclgraph_runtime_mode.name, uniform_decode)
        phase_label = "decode" if uniform_decode else "mixed prefill-decode"
        compilation_cases = tqdm(
            compilation_cases,
            disable=not self.load_config.use_tqdm_on_load,
            desc="Capturing ACL graphs ({}, {})".format(
                phase_label, aclgraph_runtime_mode.name))

    force_attention = (aclgraph_runtime_mode == CUDAGraphMode.FULL)

    def run_dummy(token_count, runtime_mode):
        # Every dummy run shares the same attention/decode settings.
        self._dummy_run(token_count,
                        aclgraph_runtime_mode=runtime_mode,
                        force_attention=force_attention,
                        uniform_decode=uniform_decode)

    # When the kv cache spec is empty, PiecewiseBackend is not initialized, and
    # compilation_case=1 will cause the dynamic shape position to be incorrectly derived.
    if not self.get_kv_cache_spec():
        run_dummy(2, CUDAGraphMode.NONE)

    # We skip EPLB here since we don't want to record dummy metrics
    for num_tokens in compilation_cases:
        # Use CUDAGraphRuntimeStyle.NONE (default) for warmup.
        # But be careful, warm up with `NONE` is orthogonal to
        # if we want to warm up attention or not. This is
        # different from the case where `FULL` implies capture
        # attention while `PIECEWISE` implies no attention.
        for _ in range(self.compilation_config.cudagraph_num_of_warmups):
            run_dummy(num_tokens, CUDAGraphMode.NONE)
        # The actual capture run for this size.
        run_dummy(num_tokens, aclgraph_runtime_mode)
|
||
|
||
def _capture_model(self):
    """
    Capture the ACL graphs for the configured batch sizes.

    Does nothing but log a warning when `self.use_aclgraph` is false.
    Exceptions raised while capturing the mixed-batch graphs are always
    re-raised; an extra hint is logged first when the message carries
    `retCode=0x7020023`.
    """
    if not self.use_aclgraph:
        logger.warning(
            "Skipping ACL graph capture. To turn on ACL graph capture, "
            "ensure `aclraph_mode` was not manually set to `NONE`")
        return

    self.initialize_aclgraph_capture()

    set_cudagraph_capturing_enabled(True)
    # Trigger ACL graph capture for specific shapes.
    # Capture the large shapes first so that the smaller shapes
    # can reuse the memory pool allocated for the large shapes.
    with graph_capture(device=self.device):
        aclgraph_mode = self.compilation_config.cudagraph_mode

        mixed_runtime_mode = aclgraph_mode.mixed_mode()
        if mixed_runtime_mode != CUDAGraphMode.NONE:
            # make sure we capture the largest batch size first
            mixed_cases = list(reversed(self.cudagraph_batch_sizes))

            try:
                self._capture_aclgraphs(
                    mixed_cases,
                    aclgraph_runtime_mode=mixed_runtime_mode,
                    uniform_decode=False)
            except Exception as e:
                ret_code_match = re.search(r'retCode=([^,\s\.]+)', str(e))
                # Determine whether the error message is caused by stream capture failure.
                if ret_code_match and ret_code_match.group(1) == '0x7020023':
                    logger.error(
                        f"ACLgraph sizes capture fail: {type(e).__name__}:\n"
                        "ACLgraph has insufficient available streams to capture the configured number of sizes. "
                        "Please verify both the availability of adequate streams and the appropriateness of the configured size count.\n\n"
                        "Recommended solutions:\n"
                        "1. Manually configure the compilation_config parameter "
                        "with a reduced set of sizes: '{\"cudagraph_capture_sizes\":[size1, size2, size3, ...]}'.\n"
                        "2. Utilize ACLgraph's full graph mode as an alternative to the piece-wise approach.\n\n"
                        f"{str(e)}")
                raise

        if aclgraph_mode.decode_mode() == CUDAGraphMode.FULL and \
            aclgraph_mode.separate_routine():
            max_num_tokens = self.scheduler_config.max_num_seqs * \
                    self.uniform_decode_query_len
            # Keep only the sizes a uniform decode batch can reach,
            # largest first.
            decode_cases = [
                size for size in reversed(self.cudagraph_batch_sizes)
                if size <= max_num_tokens
                and size >= self.uniform_decode_query_len
            ]
            self._capture_aclgraphs(
                compilation_cases=decode_cases,
                aclgraph_runtime_mode=CUDAGraphMode.FULL,
                uniform_decode=True)

    # Disable aclgraph capturing globally, so any unexpected aclgraph
    # capturing will be detected and raise an error after here.
    # Note: We don't put it into graph_capture context manager because
    # we may doing lazy capturing in future that still allows capturing
    # after here.
    set_cudagraph_capturing_enabled(False)
|
||
|
||
def capture_model(self) -> None:
    """
    Run graph capture and log how long it took and how much free NPU
    memory it consumed.
    """
    compilation_counter.num_gpu_runner_capture_triggers += 1

    # Snapshot wall-clock time and free device memory around the capture.
    started_at = time.perf_counter()
    free_bytes_before = torch.npu.mem_get_info()[0]

    self._capture_model()

    finished_at = time.perf_counter()
    free_bytes_after = torch.npu.mem_get_info()[0]

    # This usually takes 5~20 seconds.
    logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
                finished_at - started_at,
                (free_bytes_before - free_bytes_after) / (1 << 30))
|
||
|
||
def _update_tokens_for_pcp(self, tokens):
    """
    Pad the per-request scheduled token counts for `self.pcp_size` ranks
    and derive this rank's token positions.

    Side effects: overwrites `self.num_pcp_pads` and fills the leading part
    of `self.pcp_allgather_restore_idx`.

    Args:
        tokens: Number of scheduled tokens of each request (converted to an
            int32 numpy array). NOTE(review): the slicing below assumes the
            decode requests come first in the batch -- confirm with the
            caller.
    Returns:
        A tuple `(pcp_tokens, positions, unpad_mask)`: the per-request
        token count on one rank, the positions computed for
        `self.pcp_rank`, and a boolean torch mask over the padded token
        layout.
    """
    num_reqs = self.input_batch.num_reqs
    self.num_pcp_pads = self.num_pcp_pads[:num_reqs]
    tokens = np.array(tokens, dtype=np.int32)
    # A request counts as decode once its computed tokens reach its prompt
    # length.
    num_decode_reqs = sum(
        self.input_batch.num_computed_tokens_cpu[:num_reqs] >=
        self.input_batch.num_prompt_tokens[:num_reqs])
    num_decode_tokens = sum(tokens[:num_decode_reqs])
    # Round every request up to a multiple of 2 * pcp_size ...
    num_padded_scheduled_tokens = np.ceil(
        tokens /
        (2 * self.pcp_size)).astype(np.int32) * (2 * self.pcp_size)
    # ... except the leading decode requests, which are scaled by pcp_size.
    num_padded_scheduled_tokens[:num_decode_reqs] = (
        tokens[:num_decode_reqs] * self.pcp_size)
    self.num_pcp_pads = torch.tensor(num_padded_scheduled_tokens - tokens)
    cu_padded_tokens, pcp_padded_arange = \
        self._get_cumsum_and_arange(num_padded_scheduled_tokens)
    # True where the index inside a request's padded span is below its
    # unpadded token count.
    unpad_mask = torch.from_numpy(
        pcp_padded_arange < np.repeat(tokens, num_padded_scheduled_tokens))
    # For the decode part, only the first slot of every pcp_size-wide group
    # is kept. NOTE(review): this relies on the slice/reshape being a view
    # so the writes land in `unpad_mask` -- confirm.
    unpad_mask_decode = unpad_mask[:num_decode_tokens * self.pcp_size]
    unpad_mask_decode = unpad_mask_decode.reshape([-1, self.pcp_size])
    unpad_mask_decode[:, 0] = True
    unpad_mask_decode[:, 1:] = False

    pcp_tokens = num_padded_scheduled_tokens // self.pcp_size
    # Each request's per-rank tokens are split into a head and a tail chunk
    # of half size (at least 1); decode requests keep a single chunk.
    pcp_chunk_sizes = (pcp_tokens // 2).clip(min=1)
    pcp_chunk_sizes[:num_decode_reqs] = pcp_tokens[:num_decode_reqs]
    _, pcp_arange = self._get_cumsum_and_arange(pcp_tokens)
    _, pcp_chunk_arange = self._get_cumsum_and_arange(pcp_chunk_sizes)
    pcp_head_chunk_mask = pcp_arange < np.repeat(pcp_chunk_sizes,
                                                 pcp_tokens)

    def get_current_rank_positions(cu_tokens, rank):
        # Positions of the tokens handled by `rank`, offset per request by
        # the exclusive prefix of `cu_tokens`.
        positions_start_loc = np.zeros_like(cu_tokens)
        positions_start_loc[1:] = cu_tokens[:-1]
        positions = np.zeros(len(pcp_head_chunk_mask), dtype=np.int32)
        # Head chunk index is `rank`; tail chunk index is mirrored from the
        # end of the 2 * pcp_size chunks.
        head_start_loc = positions_start_loc + rank * pcp_chunk_sizes
        tail_start_loc = positions_start_loc + \
            (2 * self.pcp_size - rank - 1) * pcp_chunk_sizes
        positions[pcp_head_chunk_mask] = pcp_chunk_arange + \
            np.repeat(head_start_loc, pcp_chunk_sizes)
        # Decode reqs do not have tail chunks.
        positions[~pcp_head_chunk_mask] = \
            pcp_chunk_arange[num_decode_tokens:] + \
            np.repeat(tail_start_loc, pcp_chunk_sizes)[num_decode_tokens:]
        return positions

    # With all-zero offsets the result is relative to each request's start.
    positions = get_current_rank_positions(
        np.zeros(num_reqs, dtype=np.int32), self.pcp_rank)
    # Decode tokens are duplicated on every rank; their positions are the
    # per-request arange of the unpadded token counts.
    if num_decode_reqs > 0:
        positions[:num_decode_tokens] = self._get_cumsum_and_arange(
            tokens[:num_decode_reqs])[1]

    # Positions of every rank in the padded layout; the argsort of their
    # concatenation is stored as the all-gather restore index.
    all_positions = [
        get_current_rank_positions(cu_padded_tokens, rank_i)
        for rank_i in range(self.pcp_size)
    ]
    all_positions_tensor = torch.from_numpy(np.concatenate(all_positions))
    self.pcp_allgather_restore_idx[:all_positions_tensor.shape[0]].copy_(
        all_positions_tensor.float().argsort().long(), non_blocking=True)
    return pcp_tokens, positions, unpad_mask
|
||
|
||
def _get_cp_local_seq_lens(
    self,
    seq_lens: torch.Tensor,
    pcp_world_size: int = 1,
    dcp_world_size: int = 1,
    cp_kv_cache_interleave_size: int = 1,
) -> torch.Tensor:
    """Split sequence lengths across all (p/d)cp ranks.

    While using pcp or dcp, the kv_cache size stored on each rank may be
    different; this computes how many tokens of every sequence each rank
    holds when tokens are dealt out round-robin in groups of
    ``cp_kv_cache_interleave_size``.

    Args:
        seq_lens: 1-D integer tensor with one length per request.
        pcp_world_size: Number of prefill-context-parallel ranks.
        dcp_world_size: Number of decode-context-parallel ranks.
        cp_kv_cache_interleave_size: Number of consecutive tokens given to
            a rank before moving on to the next one.

    Returns:
        Tensor of shape ``(num_requests, pcp_world_size, dcp_world_size)``
        whose entries along the last two axes sum to the request's length.
    """
    total_world_size = pcp_world_size * dcp_world_size
    # Shape (num_requests, 1): everything below broadcasts against the rank
    # axis instead of materialising tiled copies.
    seq_lens = seq_lens.unsqueeze(-1)
    # Allocate on the same device as the input so that device tensors do not
    # hit a CPU/device mismatch.
    rank_offsets = torch.arange(total_world_size,
                                dtype=torch.int32,
                                device=seq_lens.device).unsqueeze(0)
    # Tokens every rank receives from the complete interleave rounds.
    base = (seq_lens // cp_kv_cache_interleave_size // total_world_size *
            cp_kv_cache_interleave_size)
    # Tokens left over after the complete rounds, handed out rank by rank in
    # groups of at most `cp_kv_cache_interleave_size`.
    remainder = seq_lens - base * total_world_size
    remainder = torch.clip(
        remainder - rank_offsets * cp_kv_cache_interleave_size,
        0,
        cp_kv_cache_interleave_size,
    )
    return (base + remainder).reshape(
        [-1, pcp_world_size, dcp_world_size])
|
||
|
||
def _generate_pcp_metadata(self, total_num_scheduled_tokens):
    """Build the context-parallel attention metadata for the current step.

    Args:
        total_num_scheduled_tokens: Number of tokens scheduled on this rank;
            multiplied by ``self.pcp_size`` to get the padded total.

    Returns:
        An ``AscendPrefillContextParallelMetadata`` instance, or ``None``
        when ``self.pcp_size * self.dcp_size == 1``.

    Side effects:
        Always sets ``self.num_actual_tokens_pcp_padded`` and
        ``self.long_seq_metadata``. When ``self.pcp_size > 1`` it also sets
        ``self.q_head_idx_tensor``, ``self.q_tail_idx_tensor``,
        ``self.q_full_idx``, ``self.kv_idx_names`` and
        ``self.extra_long_seq_kwargs``.
    """
    # In dummy run num_reqs == 0, update it from seq_lens
    num_reqs = self.input_batch.num_reqs or self.query_lens.size(0)
    # Requests whose computed tokens have reached the prompt length.
    num_decodes = sum(self.input_batch.num_computed_tokens_cpu[:num_reqs]
                      >= self.input_batch.num_prompt_tokens[:num_reqs])
    num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_size
    self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded
    long_seq_metadata = None
    if self.pcp_size * self.dcp_size > 1:
        decode_context_lens = self.input_batch.num_tokens[:num_decodes]
        prefill_context_lens = self.input_batch.num_computed_tokens_cpu[
            num_decodes:num_reqs]
        context_lens = np.concatenate(
            [decode_context_lens, prefill_context_lens])
        num_computed_tokens_of_pcp_dcp = torch.zeros(
            [
                num_reqs * self.decode_threshold, self.pcp_size,
                self.dcp_size
            ],
            dtype=torch.int32,
        )
        # The per-rank split does not depend on the decode slot, so compute
        # it once instead of once per slot.
        # NOTE(review): every slot therefore receives identical values, as in
        # the original loop body -- confirm that no per-slot offset is
        # intended.
        cp_local_seq_lens = self._get_cp_local_seq_lens(
            torch.tensor(context_lens),
            self.pcp_size,
            self.dcp_size,
            self.parallel_config.cp_kv_cache_interleave_size,
        )
        # For pcp + spec decode, we flatten seq_lens
        # to avoid irregular spec_attn_mask shape
        for slot in range(self.decode_threshold):
            num_computed_tokens_of_pcp_dcp[
                slot::self.decode_threshold] = cp_local_seq_lens
        long_seq_metadata = AscendPrefillContextParallelMetadata(
            num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
            num_computed_tokens_of_pcp_dcp=num_computed_tokens_of_pcp_dcp.
            numpy())
        if self.pcp_size > 1:
            q_head_idx, q_tail_idx = [], []
            kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx = [], []
            kv_with_q_tail_nomask_idx, kv_with_q_tail_mask_idx = [], []
            chunk_seqlens = []
            kv_with_q_head_nomask_seqlens, kv_with_q_tail_nomask_seqlens = [], []
            q_req_offset = 0
            kv_req_offset = 0
            # Each request is cut into 2 * pcp_size chunks; this rank owns
            # chunk `pcp_rank` (head) and its mirror from the end (tail).
            q_head_chunk_id = self.pcp_rank
            q_tail_chunk_id = self.pcp_size * 2 - 1 - self.pcp_rank
            for i, seq_len in enumerate(self.query_lens):
                if i < num_decodes:
                    continue
                # Work with plain ints so offsets and the lists below do not
                # accumulate tensor/array scalars.
                seq_len = int(seq_len)
                chunk_len = seq_len // 2
                head_kv_start = kv_req_offset + chunk_len * q_head_chunk_id
                tail_kv_start = kv_req_offset + chunk_len * q_tail_chunk_id
                chunk_seqlens.append(chunk_len)

                # Head chunk: KV before the chunk needs no mask, KV of the
                # chunk itself is attended with a mask.
                q_head_idx.extend(
                    range(q_req_offset, q_req_offset + chunk_len))
                kv_with_q_head_nomask_idx.extend(
                    range(kv_req_offset, head_kv_start))
                kv_with_q_head_mask_idx.extend(
                    range(head_kv_start, head_kv_start + chunk_len))
                kv_with_q_head_nomask_seqlens.append(chunk_len *
                                                     q_head_chunk_id)

                # Tail chunk: same split, relative to the tail chunk id.
                q_tail_idx.extend(
                    range(q_req_offset + chunk_len,
                          q_req_offset + chunk_len * 2))
                kv_with_q_tail_nomask_idx.extend(
                    range(kv_req_offset, tail_kv_start))
                kv_with_q_tail_mask_idx.extend(
                    range(tail_kv_start, tail_kv_start + chunk_len))
                kv_with_q_tail_nomask_seqlens.append(chunk_len *
                                                     q_tail_chunk_id)

                q_req_offset += seq_len
                kv_req_offset += seq_len * self.pcp_size

            # Convert lists to tensors and move to device
            def _list_to_tensor(lst, device, dtype=torch.int32):
                tensor_npu = torch.zeros(len(lst),
                                         dtype=dtype,
                                         device=device)
                tensor_npu.copy_(torch.tensor(lst, dtype=dtype),
                                 non_blocking=True)
                return tensor_npu

            q_head_idx_tensor = _list_to_tensor(q_head_idx, self.device)
            q_tail_idx_tensor = _list_to_tensor(q_tail_idx, self.device)
            self.q_head_idx_tensor = q_head_idx_tensor
            self.q_tail_idx_tensor = q_tail_idx_tensor

            # Argsort of head + tail query indices (computed via float32).
            q_full_idx = torch.cat([q_head_idx_tensor, q_tail_idx_tensor])
            q_full_idx = q_full_idx.to(torch.float32).argsort().to(
                torch.int32)
            self.q_full_idx = q_full_idx

            # Build the tensor dict in one pass rather than overwriting the
            # values of a dict while iterating over it.
            kv_idx_lists = {
                'kv_with_q_head_nomask_idx_tensor':
                kv_with_q_head_nomask_idx,
                'kv_with_q_head_mask_idx_tensor': kv_with_q_head_mask_idx,
                'kv_with_q_tail_nomask_idx_tensor':
                kv_with_q_tail_nomask_idx,
                'kv_with_q_tail_mask_idx_tensor': kv_with_q_tail_mask_idx
            }
            kv_idx_tensors = {
                key: _list_to_tensor(value, self.device)
                for key, value in kv_idx_lists.items()
            }
            self.kv_idx_names = kv_idx_tensors

            # Row 0: query chunk lengths, row 1: matching KV lengths.
            attn_mask_seqlens = torch.tensor(
                [chunk_seqlens, chunk_seqlens], dtype=torch.int32)
            head_attn_nomask_seqlens = torch.tensor(
                [chunk_seqlens, kv_with_q_head_nomask_seqlens],
                dtype=torch.int32)
            tail_attn_nomask_seqlens = torch.tensor(
                [chunk_seqlens, kv_with_q_tail_nomask_seqlens],
                dtype=torch.int32)
            pcp_prefill_mask = self.attn_mask

            self.extra_long_seq_kwargs = {
                'attn_mask_seqlens': attn_mask_seqlens,
                'head_attn_nomask_seqlens': head_attn_nomask_seqlens,
                'tail_attn_nomask_seqlens': tail_attn_nomask_seqlens,
                'pcp_prefill_mask': pcp_prefill_mask
            }

            long_seq_metadata.pcp_allgather_restore_idx = (
                self.pcp_allgather_restore_idx[:num_actual_tokens_pcp_padded])
            long_seq_metadata.cp_kv_recover_idx_for_chunk = (
                self.cp_kv_recover_idx_for_chunk)
            long_seq_metadata.q_head_idx_tensor = q_head_idx_tensor
            long_seq_metadata.q_tail_idx_tensor = q_tail_idx_tensor
            long_seq_metadata.q_full_idx = q_full_idx
            long_seq_metadata.kv_with_q_head_nomask_idx_tensor = (
                kv_idx_tensors['kv_with_q_head_nomask_idx_tensor'])
            long_seq_metadata.kv_with_q_head_mask_idx_tensor = (
                kv_idx_tensors['kv_with_q_head_mask_idx_tensor'])
            long_seq_metadata.kv_with_q_tail_nomask_idx_tensor = (
                kv_idx_tensors['kv_with_q_tail_nomask_idx_tensor'])
            long_seq_metadata.kv_with_q_tail_mask_idx_tensor = (
                kv_idx_tensors['kv_with_q_tail_mask_idx_tensor'])
            long_seq_metadata.attn_mask_seqlens = attn_mask_seqlens
            long_seq_metadata.head_attn_nomask_seqlens = (
                head_attn_nomask_seqlens)
            long_seq_metadata.tail_attn_nomask_seqlens = (
                tail_attn_nomask_seqlens)
            long_seq_metadata.pcp_prefill_mask = pcp_prefill_mask
    self.long_seq_metadata = long_seq_metadata
    return long_seq_metadata
|
||
|
||
def _generate_pcp_mtp_input(
    self,
    num_reqs: int,
    total_num_scheduled_tokens: int,
    num_scheduled_tokens: dict[str, int],
):
    """
    Record the full (un-split) input ids, positions and query start offsets.

    While pcp > 1, model inputs (input_ids, position, etc.) are split across
    the pcp group, but mtp needs to shift the original input_ids before the
    pcp split, so the original values are stored in the ``*_pcp_full``
    buffers here.
    """
    batch = self.input_batch
    num_tokens = total_num_scheduled_tokens

    # Scheduled token count of every request, in batch order.
    tokens_per_req = np.empty(num_reqs, dtype=np.int32)
    for slot, req_id in enumerate(batch.req_ids):
        tokens_per_req[slot] = num_scheduled_tokens[req_id]
    cu_tokens = np.cumsum(tokens_per_req)

    # query_start_loc: [0, cumulative counts..., -1 for the unused tail].
    query_start_loc = self.query_start_loc_pcp_full.np
    query_start_loc[0] = 0
    query_start_loc[1:num_reqs + 1] = cu_tokens
    query_start_loc[num_reqs + 1:].fill(-1)

    # For every scheduled token: the request it belongs to and its offset
    # inside that request's scheduled span.
    req_of_token = np.repeat(self.arange_np[:num_reqs], tokens_per_req)
    span_start = np.repeat(cu_tokens - tokens_per_req, tokens_per_req)
    offset_in_req = self.arange_np[:num_tokens] - span_start

    # Position = tokens already computed for the request + offset; written
    # in place into the preallocated positions buffer.
    positions = self.positions_pcp_full_np[:num_tokens]
    np.add(batch.num_computed_tokens_cpu[req_of_token],
           offset_in_req,
           out=positions)

    # Gather the token ids from the flattened (request, position) table.
    row_stride = batch.token_ids_cpu.shape[1]
    flat_token_idx = positions + req_of_token * row_stride
    cpu_ids = self.input_ids_pcp_full.cpu[:num_tokens]
    torch.index_select(batch.token_ids_cpu_tensor.flatten(),
                       0,
                       torch.from_numpy(flat_token_idx),
                       out=cpu_ids)

    # Publish both buffers to the device side.
    self.query_start_loc_pcp_full.copy_to_gpu()
    self.input_ids_pcp_full.gpu[:num_tokens].copy_(
        cpu_ids,
        non_blocking=True,
    )
|
||
|
||
|
||
@contextmanager
def _torch_cuda_wrapper():
    """Temporarily route ``torch.cuda`` event/stream APIs to ``torch.npu``.

    On entry ``torch.Event`` and ``torch.cuda.{Event, Stream, default_stream,
    current_stream, stream}`` are pointed at their ``torch.npu``
    counterparts. On exit ``torch.cuda.Event`` is replaced by a no-op
    placeholder class and the three stream accessors are (re)bound to the
    ``torch.npu`` versions.

    NOTE(review): the ``except Exception`` clause does not re-raise, so an
    exception raised inside the ``with`` body is suppressed -- confirm this
    is intended.
    """

    class _EventPlaceholder:
        # Accepts any constructor arguments; record/synchronize do nothing.

        def __init__(self, *args, **kwargs) -> None:
            self.record = lambda: None
            self.synchronize = lambda: None

    class _StreamPlaceholder:
        # Accepts any constructor arguments and carries no state.

        def __init__(self, *args, **kwargs) -> None:
            pass

    stream_accessors = ("default_stream", "current_stream", "stream")

    try:
        # Point the cuda APIs at the npu implementations; this should work
        # by default.
        torch.Event = torch.npu.Event
        for api_name in ("Event", "Stream") + stream_accessors:
            setattr(torch.cuda, api_name, getattr(torch.npu, api_name))
        yield
    except Exception:
        # If anything goes wrong, patch in placeholders. The `finally`
        # clause below runs afterwards and rebinds everything except
        # `torch.cuda.Stream`.
        torch.cuda.Event = _EventPlaceholder
        for api_name in ("Stream", ) + stream_accessors:
            setattr(torch.cuda, api_name, _StreamPlaceholder)
    finally:
        torch.cuda.Event = _EventPlaceholder
        # `torch.cuda.Stream` keeps whatever value it has at this point (the
        # original re-assigned it to itself here, which changes nothing).
        for api_name in stream_accessors:
            setattr(torch.cuda, api_name, getattr(torch.npu, api_name))
|