### What this PR does / why we need it? Since running on the Atlas 300I Duo was initially supported after #1333, this PR disables the JIT compiler for the 310P and changes the data format to NZ for the weights in the vocabulary embedding and QKV projection layers, which helps improve performance. See #1563 ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Tested manually: https://github.com/vllm-project/vllm-ascend/pull/1591#issuecomment-3028352339 Signed-off-by: Vincent Yuan <farawayboat@gmail.com>
2546 lines
119 KiB
Python
2546 lines
119 KiB
Python
#
|
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
# Copyright 2023 The vLLM team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# This file is a part of the vllm-ascend project.
|
|
# Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
|
|
#
|
|
|
|
import gc
|
|
import os
|
|
import time
|
|
import types
|
|
import weakref
|
|
from contextlib import contextmanager, nullcontext
|
|
from dataclasses import dataclass
|
|
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
|
|
|
import numpy as np
|
|
import numpy.typing as npt
|
|
import torch
|
|
import torch._dynamo.cache_size
|
|
import torch.distributed as dist
|
|
import torch.nn as nn
|
|
from torch.distributed import ReduceOp
|
|
from vllm.attention import AttentionType, get_attn_backend
|
|
from vllm.attention.layer import Attention
|
|
from vllm.config import CompilationLevel, VllmConfig
|
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
|
from vllm.distributed.parallel_state import get_dp_group, get_pp_group
|
|
from vllm.forward_context import set_forward_context
|
|
from vllm.inputs import INPUT_REGISTRY
|
|
from vllm.logger import logger
|
|
from vllm.model_executor.layers.fused_moe import FusedMoE
|
|
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
|
from vllm.model_executor.model_loader import get_model
|
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
|
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
|
from vllm.multimodal.utils import group_mm_inputs_by_modality
|
|
from vllm.pooling_params import PoolingParams
|
|
from vllm.sampling_params import SamplingType
|
|
from vllm.sequence import IntermediateTensors
|
|
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
|
|
LazyLoader, cdiv)
|
|
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
|
|
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
|
KVCacheSpec)
|
|
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
|
|
ModelRunnerOutput)
|
|
from vllm.v1.sample.metadata import SamplingMetadata
|
|
from vllm.v1.sample.sampler import Sampler
|
|
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
|
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
|
|
from vllm.v1.utils import bind_kv_cache
|
|
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
|
|
from vllm.v1.worker.utils import (gather_mm_placeholders,
|
|
sanity_check_mm_encoder_outputs,
|
|
scatter_mm_placeholders)
|
|
|
|
from vllm_ascend.ascend_config import get_ascend_config
|
|
from vllm_ascend.attention.attention import AttentionMaskBuilder
|
|
from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
|
|
AscendMetadata)
|
|
from vllm_ascend.attention.mla_v1 import CommonAttentionMetadata
|
|
from vllm_ascend.platform import NPUPlatform
|
|
from vllm_ascend.pool.metadata import PoolingMetadata
|
|
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
|
|
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
|
|
ProfileExecuteDuration, is_310p,
|
|
maybe_converting_weight_acl_format,
|
|
vllm_version_is)
|
|
from vllm_ascend.worker.eagle_proposer_v1 import EagleProposer
|
|
from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer
|
|
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
|
|
|
|
if TYPE_CHECKING:
|
|
import xgrammar as xgr # type: ignore[import-untyped]
|
|
from vllm.v1.core.sched.output import SchedulerOutput
|
|
else:
|
|
xgr = LazyLoader("xgr", globals(), "xgrammar")
|
|
|
|
import torch_npu
|
|
import vllm.envs as envs_vllm
|
|
|
|
import vllm_ascend.envs as envs_ascend
|
|
|
|
if vllm_version_is("0.9.1"):
|
|
from vllm.v1.spec_decode.utils import is_spec_decode_supported
|
|
|
|
# Runs once at import time: when `is_310p()` reports a 310P device, switch
# torch_npu's compile mode to non-JIT before any runner is constructed.
# NOTE(review): the motivation appears to be performance on 310P hardware —
# confirm against the change that introduced this guard.
if is_310p():
    torch_npu.npu.set_compile_mode(jit_compile=False)
|
|
|
|
|
|
@dataclass
class GraphCaptureContext:
    """Data handed to the code running inside `graph_capture`.

    It currently carries a single field: the NPU stream the capture runs on.
    """

    # Stream created by `graph_capture` and made current for the capture.
    stream: torch.npu.Stream
|
|
|
|
|
|
@contextmanager
def graph_capture(device: torch.device):
    """Context manager that should wrap the code capturing an NPU graph.

    A new ``torch.npu.Stream`` is created on ``device`` and made the current
    stream for the duration of the ``with`` body, so that the kernels being
    captured are explicitly separated from kernels that may be launched in
    the background on the default stream.

    Before the body runs, the new stream is told to wait on the stream that
    was current at entry (when the two differ), so that pending
    initialization work completes before capture starts.

    Args:
        device: Device on which the capture stream is created.

    Yields:
        A `GraphCaptureContext` whose ``stream`` attribute is the capture
        stream. That is the only data it carries at the moment.
    """
    capture_stream = torch.npu.Stream(device=device)
    capture_ctx = GraphCaptureContext(capture_stream)

    # Ensure all initialization operations complete before attempting to
    # capture the graph on another stream.
    entry_stream = torch.npu.current_stream()
    if entry_stream != capture_stream:
        capture_stream.wait_stream(entry_stream)

    # The second context manager is a no-op placeholder (nullcontext).
    with torch.npu.stream(capture_stream), nullcontext():
        yield capture_ctx
|
|
|
|
|
|
class NPUModelRunner(LoRAModelRunnerMixin):
|
|
|
|
def __init__(self, vllm_config: VllmConfig, device: torch.device):
    """Cache configuration, pick the attention backend and allocate buffers.

    The constructor stores the sub-configs of ``vllm_config`` as attributes,
    creates the sampler, attention backend/metadata builder and (optionally)
    the speculative-decoding drafter, and pre-allocates the persistent
    device tensors together with their CPU staging counterparts.

    Args:
        vllm_config: Top-level vLLM configuration object.
        device: Device on which the persistent device-side tensors are
            allocated.

    Raises:
        ValueError: If speculative decoding is configured with a method
            other than "ngram", "eagle", "eagle3" or "deepseek_mtp" (this is
            only checked on the last pipeline-parallel rank).
    """
    self.vllm_config = vllm_config
    self.model_config = vllm_config.model_config
    self.cache_config = vllm_config.cache_config
    self.lora_config = vllm_config.lora_config
    self.scheduler_config = vllm_config.scheduler_config
    self.speculative_config = vllm_config.speculative_config
    self.block_size = vllm_config.cache_config.block_size
    self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len,
                                       self.block_size)
    self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
    self.max_num_reqs = self.scheduler_config.max_num_seqs
    self.dp_size = vllm_config.parallel_config.data_parallel_size
    self.dp_rank = vllm_config.parallel_config.data_parallel_rank
    self.device = device
    self.dtype = self.model_config.dtype
    self.sampler = Sampler()
    # Multi-modal data support
    self.input_registry = INPUT_REGISTRY
    self.mm_registry = MULTIMODAL_REGISTRY
    self.max_num_encoder_input_tokens, self.encoder_cache_size = compute_encoder_budget(
        model_config=self.model_config,
        scheduler_config=self.scheduler_config,
        mm_registry=self.mm_registry)

    # Lazy initialization, these will be set after __init__
    self.kv_caches: List[torch.Tensor] = []
    self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}
    self.attn_mask = None
    self.attn_state = None
    self.requests: Dict[str, CachedRequestState] = {}
    self.intermediate_tensors: Optional[IntermediateTensors] = None

    # Chunked prefill follows the scheduler config only when the Ascend
    # scheduler is enabled; otherwise it is treated as always on.
    ascend_config = get_ascend_config()
    if ascend_config.ascend_scheduler_config.enabled:
        self.chunked_prefill_enabled = self.scheduler_config.chunked_prefill_enabled
    else:
        self.chunked_prefill_enabled = True

    # "auto" means the KV cache uses the model dtype.
    if self.cache_config.cache_dtype == "auto":
        self.kv_cache_dtype = self.dtype
    else:
        self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
            self.cache_config.cache_dtype]

    self.is_multimodal_model = self.model_config.is_multimodal_model
    self.is_pooling_model = self.model_config.pooler_config is not None
    # `inputs_embeds` only exists on multimodal models.
    if self.is_multimodal_model:
        self.inputs_embeds = torch.zeros(
            (self.max_num_tokens, self.model_config.get_hidden_size()),
            dtype=self.dtype,
            device=self.device)

    # One row per request; the column count is ceil(max_model_len /
    # block_size).
    self.graph_block_tables = np.zeros(
        (self.max_num_reqs,
         (self.model_config.max_model_len + self.block_size - 1) //
         self.block_size),
        dtype=np.int32)

    # Set up Attention
    self.attn_backend = get_attn_backend(
        0,
        self.dtype,
        None,
        self.block_size,
        self.model_config.is_attention_free,
        use_mla=self.model_config.use_mla,
    )
    # The builder receives a weak proxy of the runner rather than a strong
    # reference to it.
    self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
        weakref.proxy(self))

    # Set up speculative decoding.
    self.use_aux_hidden_state_outputs = False
    self.use_spec_decode = False
    self.spec_attn_mask = None
    self.use_eagle = False
    self.drafter = None
    if self.speculative_config:
        self.use_spec_decode = True
        # 2048 x 2048 boolean mask that is True strictly above the diagonal.
        self.spec_attn_mask = torch.triu(torch.ones(2048,
                                                    2048,
                                                    dtype=torch.bool),
                                         diagonal=1).to("npu")
        # The drafter and rejection sampler are only created on the last
        # pipeline-parallel rank.
        if get_pp_group().is_last_rank:
            if self.speculative_config.method == "ngram":
                self.drafter = NgramProposer(self.vllm_config)
            elif self.speculative_config.method in ["eagle", "eagle3"]:
                self.use_eagle = True
                self.drafter = EagleProposer(self.vllm_config, self.device,
                                             self)  # type: ignore
                if self.speculative_config.method == "eagle3":
                    self.use_aux_hidden_state_outputs = True
            elif self.speculative_config.method == 'deepseek_mtp':
                self.drafter = MtpProposer(self.vllm_config, self)
            else:
                raise ValueError("Unknown speculative decoding method: "
                                 f"{self.speculative_config.method}")
            self.rejection_sampler = AscendRejectionSampler()

    # Persistent device-side input buffers.
    self.input_ids = torch.zeros(self.max_num_tokens,
                                 dtype=torch.int32,
                                 device=self.device)
    self.positions = torch.zeros(self.max_num_tokens,
                                 dtype=torch.int64,
                                 device=self.device)
    self.query_start_loc = torch.zeros(self.max_num_reqs + 1,
                                       dtype=torch.int32,
                                       device=self.device)
    self.seq_lens = torch.zeros(self.max_num_reqs,
                                dtype=torch.int32,
                                device=self.device)

    self.uses_mrope = self.model_config.uses_mrope
    # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
    if self.uses_mrope:
        # NOTE: `mrope_positions` is implemented with one additional dummy
        # position on purpose to make it non-contiguous so that it can work
        # with torch compile.
        # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923

        # NOTE: When M-RoPE is enabled, position ids are 3D regardless of
        # the modality of inputs. For text-only inputs, each dimension has
        # identical position IDs, making M-RoPE functionally equivalent to
        # 1D-RoPE.
        # See page 5 of https://arxiv.org/abs/2409.12191
        self.mrope_positions = torch.zeros((3, self.max_num_tokens + 1),
                                           dtype=torch.int64,
                                           device=self.device)
        self.mrope_positions_cpu = torch.zeros(
            (3, self.max_num_tokens + 1),
            dtype=torch.int64,
            device="cpu",
            pin_memory=True)
        self.mrope_positions_np = self.mrope_positions_cpu.numpy()

    # OPTIMIZATION: Cache the tensors rather than creating them every step.
    self.arange_np: npt.NDArray[np.int32] = np.arange(max(
        self.max_num_reqs + 1, self.model_config.max_model_len,
        self.max_num_tokens),
                                                      dtype=np.int32)
    # NOTE(woosuk): These tensors are "stateless", i.e., they are literally
    # a faster version of creating a new tensor every time. Thus, we should
    # not make any assumptions about the values in these tensors.
    # Each pinned CPU tensor below is paired with a `*_np` numpy view created
    # via `.numpy()`.
    self.input_ids_cpu = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int32,
                                     device="cpu",
                                     pin_memory=True)
    self.positions_cpu = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int64,
                                     device="cpu",
                                     pin_memory=True)
    self.positions_np = self.positions_cpu.numpy()

    self.slot_mapping_cpu = torch.zeros(self.max_num_tokens,
                                        dtype=torch.int32,
                                        device="cpu",
                                        pin_memory=True)
    self.slot_mapping_np = self.slot_mapping_cpu.numpy()
    self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
                                           dtype=torch.int32,
                                           device="cpu",
                                           pin_memory=True)
    self.query_start_loc_np = self.query_start_loc_cpu.numpy()
    self.seq_lens_cpu = torch.zeros(self.max_num_reqs,
                                    dtype=torch.int32,
                                    device="cpu",
                                    pin_memory=True)
    self.seq_lens_np = self.seq_lens_cpu.numpy()

    # ACL graph mode requires PIECEWISE compilation and no enforce_eager.
    self.use_aclgraph = (self.vllm_config.compilation_config.level
                         == CompilationLevel.PIECEWISE
                         and not self.model_config.enforce_eager)
    # Capture sizes are stored in reverse of the configured order.
    self.aclgraph_batch_sizes = list(
        reversed(
            self.vllm_config.compilation_config.cudagraph_capture_sizes))

    # NOTE: Pre-construct a mask matrix to improve the efficiency of
    # attention mask construction during inference.
    # Note that the length of the matrix needs to be carefully balanced: a
    # matrix that is too large will consume excessive VRAM, while a matrix
    # that is too small will require dynamic concatenation during inference,
    # leading to performance degradation.
    # Therefore, an environment variable is added here to dynamically set
    # the size of the pre-constructed mask matrix based on requirements.
    # `os.getenv` yields a str when the variable is set and the int default
    # otherwise; the `int(...)` below normalizes both cases.
    mask_len = os.getenv("PAGED_ATTENTION_MASK_LEN", 10000)
    attn_mask_len = min(self.model_config.max_model_len, int(mask_len))
    self.attn_mask_builder = AttentionMaskBuilder.initialize_from_len(
        attn_mask_len, self.dtype)

    self.torchair_compiled_model = None  # type: ignore
    self.torchair_compiled_models = {}  # type: ignore
    self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
    self.use_cached_npu_graph = ascend_config.torchair_graph_config.use_cached_graph
    self.torchair_graph_batch_sizes = ascend_config.torchair_graph_config.graph_batch_sizes
    if ascend_config.torchair_graph_config.graph_batch_sizes_init:
        self.init_torchair_graph_batch_sizes()
    if len(self.torchair_graph_batch_sizes) == 0:
        # TODO(zzzzwwjj): check torchair_graph_batch_sizes init code
        self.torchair_graph_batch_sizes = [self.max_num_reqs]

    # Process-wide torch settings: the dynamo cache limit is raised by one
    # entry per torchair graph batch size.
    torch._dynamo.cache_size.config.cache_size_limit += len(
        self.torchair_graph_batch_sizes)
    torch._dynamo.config.capture_dynamic_output_shape_ops = True
    torch._logging.set_logs(
        recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES)
|
|
|
|
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
    """Update the cached states and the persistent batch with the scheduler
    output.

    The SamplingMetadata is updated and copied to the NPU if there is a
    new/resumed/paused/finished request in the batch.

    The method proceeds in this order:
      1. drop finished requests from `self.requests`, the encoder cache and
         the persistent batch;
      2. free individual cached encoder outputs;
      3. remove requests that are cached in the batch but not scheduled in
         this step (their `self.requests` entry is kept);
      4. create `CachedRequestState` entries for new requests;
      5. refresh the state of running/resumed requests (the scheduler output
         layout differs between vLLM 0.9.1 and later versions);
      6. (re-)insert requests into the persistent batch, condense it, and
         refresh the sampling metadata if the batch changed.

    Args:
        scheduler_output: The scheduler's output for the current step.
    """
    # Remove finished requests from the cached states.
    for req_id in scheduler_output.finished_req_ids:
        self.requests.pop(req_id, None)
        self.encoder_cache.pop(req_id, None)
    # Remove the finished requests from the persistent batch.
    # NOTE(woosuk): There could be an edge case where finished_req_ids and
    # scheduled_req_ids overlap. This happens when a request is aborted and
    # then resubmitted with the same ID. In this case, we treat them as two
    # distinct requests - clearing the cached states for the first request
    # and handling the second as a new request.
    removed_req_indices: List[int] = []
    for req_id in scheduler_output.finished_req_ids:
        req_index = self.input_batch.remove_request(req_id)
        if req_index is not None:
            removed_req_indices.append(req_index)

    # Free the cached encoder outputs.
    for req_id, input_id in scheduler_output.free_encoder_input_ids:
        encoder_outputs = self.encoder_cache.get(req_id)
        if encoder_outputs is not None:
            encoder_outputs.pop(input_id, None)
            if not encoder_outputs:
                self.encoder_cache.pop(req_id, None)

    # Remove the unscheduled requests from the persistent batch.
    # NOTE(woosuk): The unscheduled requests are either preempted requests
    # or running requests that are not scheduled in this step. We remove
    # them from the persistent batch but keep their cached states since
    # they will be scheduled again sometime in the future.
    scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys()
    cached_req_ids = self.input_batch.req_id_to_index.keys()
    unscheduled_req_ids = cached_req_ids - scheduled_req_ids
    # NOTE(woosuk): The persistent batch optimization assumes that
    # consecutive batches contain mostly the same requests. If batches
    # have low request overlap (e.g., alternating between two distinct
    # sets of requests), this optimization becomes very inefficient.
    for req_id in unscheduled_req_ids:
        req_index = self.input_batch.remove_request(req_id)
        assert req_index is not None
        removed_req_indices.append(req_index)

    req_ids_to_add: List[str] = []
    # Add new requests to the cached states.
    for new_req_data in scheduler_output.scheduled_new_reqs:
        req_id = new_req_data.req_id
        sampling_params = new_req_data.sampling_params
        # Only seeded random sampling gets its own generator.
        if sampling_params and \
            sampling_params.sampling_type == SamplingType.RANDOM_SEED:
            generator = torch.Generator(device=self.device)
            generator.manual_seed(sampling_params.seed)
        else:
            generator = None

        # For vllm v0.9.1 version compatibility, we check if
        # `pooling_params` is present in the new request data.
        pooling_params = getattr(new_req_data, "pooling_params", None)
        self.requests[req_id] = CachedRequestState(
            req_id=req_id,
            prompt_token_ids=new_req_data.prompt_token_ids,
            mm_inputs=new_req_data.mm_inputs,
            mm_positions=new_req_data.mm_positions,
            sampling_params=sampling_params,
            pooling_params=pooling_params,
            generator=generator,
            block_ids=new_req_data.block_ids,
            num_computed_tokens=new_req_data.num_computed_tokens,
            output_token_ids=[],
            lora_request=new_req_data.lora_request,
        )

        # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
        if self.uses_mrope:
            # Collect the per-modality grid/length info from every
            # multimodal input of this request.
            image_grid_thw = []
            video_grid_thw = []
            second_per_grid_ts = []
            audio_feature_lengths = []
            use_audio_in_video = False
            for mm_input in self.requests[req_id].mm_inputs:
                if mm_input.get("image_grid_thw") is not None:
                    image_grid_thw.extend(
                        mm_input["image_grid_thw"].tolist())
                if mm_input.get("video_grid_thw") is not None:
                    video_grid_thw.extend(
                        mm_input["video_grid_thw"].tolist())
                if mm_input.get("second_per_grid_ts") is not None:
                    second_per_grid_ts.extend(
                        mm_input["second_per_grid_ts"])
                if mm_input.get("audio_feature_lengths") is not None:
                    audio_feature_lengths.extend(
                        mm_input["audio_feature_lengths"])
                if mm_input.get("use_audio_in_video") is True:
                    use_audio_in_video = True

            hf_config = self.model_config.hf_config

            self.requests[req_id].mrope_positions, \
                self.requests[req_id].mrope_position_delta = \
                MRotaryEmbedding.get_input_positions_tensor(
                    self.requests[req_id].prompt_token_ids,
                    hf_config=hf_config,
                    image_grid_thw=image_grid_thw,
                    video_grid_thw=video_grid_thw,
                    second_per_grid_ts=second_per_grid_ts,
                    audio_feature_lengths=audio_feature_lengths,
                    use_audio_in_video=use_audio_in_video,
                )

        req_ids_to_add.append(req_id)

    # Update the states of the running/resumed requests.
    if vllm_version_is("0.9.1"):
        # vLLM 0.9.1: `scheduled_cached_reqs` is iterated as a sequence of
        # per-request records.
        for req_data in scheduler_output.scheduled_cached_reqs:
            req_id = req_data.req_id
            req_state = self.requests[req_id]

            # Update the cached states.
            num_computed_tokens = req_data.num_computed_tokens
            req_state.num_computed_tokens = num_computed_tokens
            # Add the sampled token(s) from the previous step (if any).
            # This doesn't include "unverified" tokens like spec decode tokens.
            num_new_tokens = (num_computed_tokens +
                              len(req_data.new_token_ids) -
                              req_state.num_tokens)
            if num_new_tokens == 1:
                # Avoid slicing list in most common case.
                req_state.output_token_ids.append(
                    req_data.new_token_ids[-1])
            elif num_new_tokens > 0:
                req_state.output_token_ids.extend(
                    req_data.new_token_ids[-num_new_tokens:])
            # Update the block IDs.
            if not req_data.resumed_from_preemption:
                # Append the new blocks to the existing block IDs.
                for block_ids, new_block_ids in zip(  # type: ignore[call-overload]
                        req_state.block_ids,
                        req_data.new_block_ids,
                        strict=True):
                    block_ids.extend(new_block_ids)
            else:
                # The request is resumed from preemption.
                # Replace the existing block IDs with the new ones.
                req_state.block_ids = req_data.new_block_ids

            req_index = self.input_batch.req_id_to_index.get(req_id)
            if req_index is None:
                # The request is not in the persistent batch.
                # The request was either preempted and resumed later, or was not
                # scheduled in the previous step and needs to be added again.
                req_ids_to_add.append(req_id)
                continue

            # Update the persistent batch.
            self.input_batch.num_computed_tokens_cpu[req_index] = (
                num_computed_tokens)

            self.input_batch.block_table.append_row(
                req_data.new_block_ids, req_index)
            # Add new_token_ids to token_ids_cpu.
            start_token_index = num_computed_tokens
            end_token_index = num_computed_tokens + len(
                req_data.new_token_ids)
            self.input_batch.token_ids_cpu[
                req_index,
                start_token_index:end_token_index] = req_data.new_token_ids
            self.input_batch.num_tokens_no_spec[
                req_index] = end_token_index
            # Add spec_token_ids to token_ids_cpu.
            spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
                req_id, ())
            if spec_token_ids:
                start_index = end_token_index
                end_token_index += len(spec_token_ids)
                self.input_batch.token_ids_cpu[
                    req_index,
                    start_index:end_token_index] = spec_token_ids
            # NOTE(woosuk): `num_tokens` here may include spec decode tokens.
            self.input_batch.num_tokens[req_index] = end_token_index
    else:
        # Other vLLM versions: `scheduled_cached_reqs` is a single object
        # holding parallel per-request lists indexed by `i`.
        req_data = scheduler_output.scheduled_cached_reqs
        is_last_rank = get_pp_group().is_last_rank
        for i, req_id in enumerate(req_data.req_ids):
            req_state = self.requests[req_id]
            num_computed_tokens = req_data.num_computed_tokens[i]
            new_block_ids = req_data.new_block_ids[i]
            resumed_from_preemption = req_data.resumed_from_preemption[i]

            req_state.num_computed_tokens = num_computed_tokens
            # `new_token_ids` is only read from the scheduler output on
            # non-last pipeline-parallel ranks.
            if not is_last_rank:
                new_token_ids = req_data.new_token_ids[i]
                # Add the sampled token(s) from the previous step (if any).
                # This doesn't include "unverified" tokens like spec decode tokens.
                num_new_tokens = (num_computed_tokens +
                                  len(new_token_ids) -
                                  req_state.num_tokens)
                if num_new_tokens == 1:
                    # Avoid slicing list in most common case.
                    req_state.output_token_ids.append(new_token_ids[-1])
                elif num_new_tokens > 0:
                    req_state.output_token_ids.extend(
                        new_token_ids[-num_new_tokens:])
            # Update the block IDs.
            if not resumed_from_preemption:
                # Append the new blocks to the existing block IDs.
                for block_ids, new_ids in zip(  # type: ignore[call-overload]
                        req_state.block_ids, new_block_ids):
                    block_ids.extend(new_ids)
            else:
                # The request is resumed from preemption.
                # Replace the existing block IDs with the new ones.
                req_state.block_ids = new_block_ids

            req_index = self.input_batch.req_id_to_index.get(req_id)
            if req_index is None:
                # The request is not in the persistent batch.
                # The request was either preempted and resumed later, or was not
                # scheduled in the previous step and needs to be added again.
                req_ids_to_add.append(req_id)
                continue

            # Update the persistent batch.
            self.input_batch.num_computed_tokens_cpu[req_index] = (
                num_computed_tokens)

            self.input_batch.block_table.append_row(
                new_block_ids, req_index)

            # NOTE(review): `new_token_ids` and `end_token_index` are only
            # bound on non-last ranks, so all token/spec-token bookkeeping
            # stays inside this guard.
            if not is_last_rank:
                # Add new_token_ids to token_ids_cpu.
                start_token_index = num_computed_tokens
                end_token_index = num_computed_tokens + len(new_token_ids)
                self.input_batch.token_ids_cpu[
                    req_index,
                    start_token_index:end_token_index] = new_token_ids
                self.input_batch.num_tokens_no_spec[
                    req_index] = end_token_index
                # Add spec_token_ids to token_ids_cpu.
                spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
                    req_id, ())
                if spec_token_ids:
                    start_index = end_token_index
                    end_token_index += len(spec_token_ids)
                    self.input_batch.token_ids_cpu[
                        req_index,
                        start_index:end_token_index] = spec_token_ids
                # NOTE(woosuk): `num_tokens` here may include spec decode tokens.
                self.input_batch.num_tokens[req_index] = end_token_index

    # Check if the batch has changed. If not, we can skip copying the
    # sampling metadata from CPU to GPU.
    batch_changed = len(removed_req_indices) > 0 or len(req_ids_to_add) > 0

    # Add the new or resumed requests to the persistent batch.
    # The smaller empty indices are filled first.
    # (Sorted descending so that `pop()` hands out the smallest index.)
    removed_req_indices.sort(reverse=True)
    for req_id in req_ids_to_add:
        req_state = self.requests[req_id]
        if removed_req_indices:
            # Fill the empty index.
            req_index = removed_req_indices.pop()
        else:
            # Append to the end.
            req_index = None
        self.input_batch.add_request(req_state, req_index)

    # Condense the batched states if there are empty indices.
    if removed_req_indices:
        self.input_batch.condense(removed_req_indices)

    if batch_changed:
        self.input_batch.refresh_sampling_metadata()
|
|
|
|
def _get_forward_metadata_across_dp(
        self, total_num_scheduled_tokens: int,
        with_prefill: bool) -> tuple[int, bool]:
    """Combine this rank's forward metadata with the other DP ranks'.

    The local token count and prefill flag are packed into one int32 CPU
    tensor and MAX-reduced over the data-parallel CPU group.

    Args:
        total_num_scheduled_tokens: Tokens scheduled on this rank.
        with_prefill: Whether this rank's step includes prefill.

    Returns:
        ``(max_tokens, any_prefill)``: the element-wise maximum over the
        group, i.e. the largest token count and whether any rank reported
        a prefill.
    """
    local_metadata = [total_num_scheduled_tokens, with_prefill]
    reduced = torch.tensor(local_metadata, device="cpu", dtype=torch.int32)

    # In-place reduction: `reduced` holds the group-wide maxima afterwards.
    dist.all_reduce(reduced,
                    op=ReduceOp.MAX,
                    group=get_dp_group().cpu_group)

    max_tokens = int(reduced[0])
    any_prefill = bool(reduced[1] > 0)
    return max_tokens, any_prefill
|
|
|
|
def get_eagle_atten_dict(
    self,
    scheduler_output: "SchedulerOutput",
) -> dict[str, AscendMetadata]:
    """Prepare the step inputs and build per-layer attention metadata.

    Besides returning the metadata, this method fills the persistent CPU
    staging buffers (positions, input ids, slot mapping, query start
    locations, sequence lengths), copies them into the device-side buffers,
    commits the block table and sets `self.query_lens`.

    Args:
        scheduler_output: The scheduler's output for the current step. Must
            schedule at least one token, and the persistent batch must
            contain at least one request (both are asserted).

    Returns:
        A dict mapping each layer name of every KV cache group to the
        attention metadata built for that group.
    """
    total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
    assert total_num_scheduled_tokens > 0
    num_reqs = self.input_batch.num_reqs
    assert num_reqs > 0

    # OPTIMIZATION: Start copying the block table first.
    # This way, we can overlap the copy with the following CPU operations.
    self.input_batch.block_table.commit(num_reqs)

    # Get the number of scheduled tokens for each request.
    req_ids = self.input_batch.req_ids
    tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
    num_scheduled_tokens = np.array(tokens, dtype=np.int32)
    max_num_scheduled_tokens = max(tokens)
    self.query_lens = torch.from_numpy(num_scheduled_tokens)
    # Get request indices.
    # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
    req_indices = np.repeat(self.arange_np[:num_reqs],
                            num_scheduled_tokens)

    # cu_num_tokens: [2, 5, 3] -> [2, 7, 10]
    # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
    cu_num_tokens, arange = self._get_cumsum_and_arange(
        num_scheduled_tokens)

    # Get positions.
    # The result is written in place into the leading slice of
    # `self.positions_np`.
    positions_np = self.positions_np[:total_num_scheduled_tokens]
    np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
           arange,
           out=positions_np)

    # Calculate M-RoPE positions.
    # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
    if self.uses_mrope:
        self._calc_mrope_positions(scheduler_output)

    # Get token indices.
    # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
    # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
    # where M is the max_model_len.
    token_indices = (positions_np +
                     req_indices * self.input_batch.token_ids_cpu.shape[1])

    # NOTE(woosuk): We use torch.index_select instead of np.take here
    # because torch.index_select is much faster than np.take for large
    # tensors.
    torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
                       0,
                       torch.from_numpy(token_indices),
                       out=self.input_ids_cpu[:total_num_scheduled_tokens])

    # Compute the slot mapping for each KV cache group.
    # NOTE(Chen): there is exactly one KV cache group that contains all
    # attention layers in the model for now, so the current logic for
    # getting attn_metadata is not related to kv_cache_group information.
    # Will extend this part to support multiple KV cache groups later.
    for kv_cache_group_id, kv_cache_group_spec in enumerate(
            self.kv_cache_config.kv_cache_groups):
        block_size = kv_cache_group_spec.kv_cache_spec.block_size
        block_table = self.input_batch.block_table[kv_cache_group_id]
        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
        # where K is the max_num_blocks_per_req and the block size is 2.
        # NOTE(woosuk): We can't simply use `token_indices // block_size`
        # here because M (max_model_len) is not necessarily divisible by
        # block_size.
        block_table_indices = (
            req_indices * block_table.max_num_blocks_per_req +
            positions_np // block_size)
        block_table_cpu = block_table.get_cpu_tensor()
        block_numbers = block_table_cpu.flatten(
        )[block_table_indices].numpy()
        block_offsets = positions_np % block_size
        # slot = block_number * block_size + offset_within_block, written
        # in place into this group's slot-mapping buffer.
        np.add(
            block_numbers * block_size,
            block_offsets,
            out=block_table.slot_mapping_np[:total_num_scheduled_tokens])

    # Prepare the attention metadata.
    self.query_start_loc_np[0] = 0
    self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens

    self.seq_lens_np[:num_reqs] = (
        self.input_batch.num_computed_tokens_cpu[:num_reqs] +
        num_scheduled_tokens)

    # Copy the tensors to the NPU.
    self.input_ids[:total_num_scheduled_tokens].copy_(
        self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
    if self.uses_mrope:
        # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
        self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
            self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
            non_blocking=True)
    else:
        # Common case (1D positions)
        self.positions[:total_num_scheduled_tokens].copy_(
            self.positions_cpu[:total_num_scheduled_tokens],
            non_blocking=True)

    self.query_start_loc[:num_reqs + 1].copy_(
        self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
    self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
                                   non_blocking=True)

    # Reset the unused tail entries: 0 for seq_lens and -1 for
    # query_start_loc. Needed for reshape_and_cache
    self.seq_lens[num_reqs:].fill_(0)
    self.query_start_loc[num_reqs + 1:].fill_(-1)

    attn_metadata: dict[str, AscendMetadata] = {}
    # Prepare the attention metadata for each KV cache group and make layers
    # in the same group share the same metadata.
    for kv_cache_group_id, kv_cache_group_spec in enumerate(
            self.kv_cache_config.kv_cache_groups):

        # Prepare for cascade attention if enabled & beneficial.
        # (Currently always 0 here.)
        common_prefix_len = 0

        attn_metadata_i = self.attn_metadata_builder.build(
            num_reqs=num_reqs,
            num_actual_tokens=total_num_scheduled_tokens,
            max_query_len=max_num_scheduled_tokens,
            common_prefix_len=common_prefix_len,
        )
        for layer_name in kv_cache_group_spec.layer_names:
            attn_metadata[layer_name] = attn_metadata_i

    return attn_metadata
|
|
|
|
def get_model(self) -> nn.Module:
    """Return the model object currently held in ``self.model``."""
    model = self.model
    return model
|
|
|
|
def _make_attention_mask(self, seq_lens, query_lens, position,
                         attn_state) -> Optional[torch.Tensor]:
    """Build the attention mask that matches the batch's attention state.

    Args:
        seq_lens: Per-request sequence lengths. Read only for the
            chunked-prefill and prefill-without-cache states.
        query_lens: Per-request query lengths. Read only for chunked prefill.
        position: Token positions, forwarded to the split-fuse mask builder.
        attn_state: The ``AscendAttentionState`` of the current batch.

    Returns:
        The mask produced by ``self.attn_mask_builder`` for the three
        prefill-like states, or ``None`` for every other state (no mask is
        built, e.g. for decode-only batches).
    """
    mask_builder = self.attn_mask_builder

    # Chunked prefill situation.
    if attn_state == AscendAttentionState.ChunkedPrefill:
        return mask_builder.get_splitfuse_attn_mask(seq_lens, query_lens,
                                                    position, self.dtype,
                                                    self.device)

    # Prefill without cache situation: size the mask by the longest sequence
    # (0 when ``seq_lens`` is empty).
    if attn_state == AscendAttentionState.PrefillNoCache:
        max_seq_len = max(seq_lens, default=0)
        return mask_builder.get_attn_mask(max_seq_len, self.dtype,
                                          self.device)

    # Prefill with cache hit.
    # NOTE(review): the size is hard-coded to 128 here; presumably the
    # attention kernel only needs a fixed-size mask in this state - confirm.
    if attn_state == AscendAttentionState.PrefillCacheHit:
        return mask_builder.get_attn_mask(128, self.dtype, self.device)

    # Decode-only situation (and any other state): no mask.
    return None
|
|
|
|
def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
|
|
mrope_pos_ptr = 0
|
|
for index, req_id in enumerate(self.input_batch.req_ids):
|
|
req = self.requests[req_id]
|
|
assert req.mrope_positions is not None
|
|
|
|
num_computed_tokens = \
|
|
self.input_batch.num_computed_tokens_cpu[index]
|
|
num_scheduled_tokens = \
|
|
scheduler_output.num_scheduled_tokens[req_id]
|
|
num_prompt_tokens = len(req.prompt_token_ids)
|
|
|
|
if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
|
|
prompt_part_len = max(0,
|
|
num_prompt_tokens - num_computed_tokens)
|
|
completion_part_len = max(
|
|
0, num_scheduled_tokens - prompt_part_len)
|
|
else:
|
|
prompt_part_len = num_scheduled_tokens
|
|
completion_part_len = 0
|
|
|
|
assert num_scheduled_tokens == prompt_part_len + completion_part_len
|
|
|
|
if prompt_part_len > 0:
|
|
# prompt's mrope_positions are pre-computed
|
|
dst_start = mrope_pos_ptr
|
|
dst_end = mrope_pos_ptr + prompt_part_len
|
|
src_start = num_computed_tokens
|
|
src_end = num_computed_tokens + prompt_part_len
|
|
|
|
self.mrope_positions_cpu[:, dst_start:dst_end] = \
|
|
req.mrope_positions[:,src_start:src_end]
|
|
|
|
mrope_pos_ptr += prompt_part_len
|
|
|
|
if completion_part_len > 0:
|
|
# compute completion's mrope_positions on-the-fly
|
|
dst_start = mrope_pos_ptr
|
|
dst_end = mrope_pos_ptr + completion_part_len
|
|
|
|
if vllm_version_is("0.9.1"):
|
|
self.mrope_positions_cpu[:, dst_start:dst_end] = \
|
|
MRotaryEmbedding.get_next_input_positions_tensor(
|
|
req.mrope_position_delta,
|
|
context_len=num_computed_tokens +
|
|
prompt_part_len,
|
|
seq_len=num_computed_tokens +
|
|
prompt_part_len +
|
|
completion_part_len,
|
|
)
|
|
else:
|
|
MRotaryEmbedding.get_next_input_positions_tensor(
|
|
out=self.mrope_positions_np,
|
|
out_offset=dst_start,
|
|
mrope_position_delta=req.mrope_position_delta,
|
|
context_len=num_computed_tokens + prompt_part_len,
|
|
num_new_tokens=completion_part_len,
|
|
)
|
|
|
|
mrope_pos_ptr += completion_part_len
|
|
|
|
def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
|
|
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
|
|
if not scheduled_encoder_inputs:
|
|
return
|
|
|
|
# Batch the multi-modal inputs.
|
|
mm_inputs = list[MultiModalKwargs]()
|
|
req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
|
|
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
|
|
req_state = self.requests[req_id]
|
|
|
|
for mm_input_id in encoder_input_ids:
|
|
mm_inputs.append(req_state.mm_inputs[mm_input_id])
|
|
req_ids_pos.append(
|
|
(req_id, mm_input_id, req_state.mm_positions[mm_input_id]))
|
|
|
|
# Batch mm inputs as much as we can: if a request in the batch has
|
|
# multiple modalities or a different modality than the previous one,
|
|
# we process it separately to preserve item order.
|
|
# FIXME(ywang96): This is a hacky way to deal with multiple modalities
|
|
# in the same batch while still being able to benefit from batching
|
|
# multimodal inputs. The proper solution should be reordering the
|
|
# encoder outputs.
|
|
grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs)
|
|
|
|
encoder_outputs = []
|
|
for grouped_mm_inputs in grouped_mm_inputs_list:
|
|
batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
|
|
batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
|
|
device=self.device)
|
|
|
|
# Run the encoder.
|
|
# `curr_group_outputs` is either of the following:
|
|
# 1. A tensor of shape (num_items, feature_size, hidden_size)
|
|
# in case feature_size is fixed across all multimodal items.
|
|
# 2. A list or tuple (length: num_items) of tensors, each of shape
|
|
# (feature_size, hidden_size) in case the feature size is dynamic
|
|
# depending on the input multimodal items.
|
|
curr_group_outputs = self.model.get_multimodal_embeddings(
|
|
**batched_mm_inputs)
|
|
|
|
sanity_check_mm_encoder_outputs(
|
|
curr_group_outputs,
|
|
expected_num_items=len(grouped_mm_inputs),
|
|
)
|
|
|
|
for output in curr_group_outputs:
|
|
encoder_outputs.append(output)
|
|
|
|
# Cache the encoder outputs.
|
|
for (req_id, input_id, pos_info), output in zip(
|
|
req_ids_pos,
|
|
encoder_outputs,
|
|
):
|
|
if req_id not in self.encoder_cache:
|
|
self.encoder_cache[req_id] = {}
|
|
|
|
self.encoder_cache[req_id][input_id] = scatter_mm_placeholders(
|
|
output,
|
|
is_embed=pos_info.is_embed,
|
|
)
|
|
|
|
def _gather_mm_embeddings(
|
|
self,
|
|
scheduler_output: "SchedulerOutput",
|
|
) -> list[torch.Tensor]:
|
|
mm_embeds: list[torch.Tensor] = []
|
|
for req_id in self.input_batch.req_ids:
|
|
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
|
|
req_id]
|
|
req_state = self.requests[req_id]
|
|
num_computed_tokens = req_state.num_computed_tokens
|
|
mm_positions = req_state.mm_positions
|
|
for i, pos_info in enumerate(mm_positions):
|
|
start_pos = pos_info.offset
|
|
num_encoder_tokens = pos_info.length
|
|
|
|
# The encoder output is needed if the two ranges overlap:
|
|
# [num_computed_tokens,
|
|
# num_computed_tokens + num_scheduled_tokens) and
|
|
# [start_pos, start_pos + num_encoder_tokens)
|
|
if start_pos >= num_computed_tokens + num_scheduled_tokens:
|
|
# The encoder output is not needed in this step.
|
|
break
|
|
if start_pos + num_encoder_tokens <= num_computed_tokens:
|
|
# The encoder output is already processed and stored
|
|
# in the decoder's KV cache.
|
|
continue
|
|
|
|
start_idx = max(num_computed_tokens - start_pos, 0)
|
|
end_idx = min(
|
|
num_computed_tokens - start_pos + num_scheduled_tokens,
|
|
num_encoder_tokens)
|
|
assert start_idx < end_idx
|
|
assert req_id in self.encoder_cache
|
|
assert i in self.encoder_cache[req_id]
|
|
encoder_output = self.encoder_cache[req_id][i]
|
|
|
|
if (is_embed := pos_info.is_embed) is not None:
|
|
is_embed = is_embed[start_idx:end_idx]
|
|
|
|
mm_embeds_item = gather_mm_placeholders(
|
|
encoder_output[start_idx:end_idx],
|
|
is_embed=is_embed,
|
|
)
|
|
mm_embeds.append(mm_embeds_item)
|
|
return mm_embeds
|
|
|
|
def _process_reqs(
    self,
    scheduler_output: "SchedulerOutput",
    intermediate_tensors: Optional[IntermediateTensors] = None,
) -> tuple[SpecDecodeMetadata, torch.Tensor, SpecDecodeMetadata,
           torch.Tensor, int, torch.Tensor, torch.Tensor, np.ndarray]:
    """Prepare the inputs for the scheduled batch and run one forward pass.

    In order, this method: pads the token count when the ACL graph path
    applies, lets the attention-metadata builder reorder the batch, commits
    the block table, computes per-token positions and the slot mapping on
    the host, classifies the batch into an ``AscendAttentionState``, builds
    the attention mask and attention metadata, stages input ids (or input
    embeddings for multimodal models) on the device, and calls the model.

    Args:
        scheduler_output: The scheduler's decision for this step. Read here:
            ``total_num_scheduled_tokens``, ``num_scheduled_tokens`` and
            ``scheduled_spec_decode_tokens``.
        intermediate_tensors: Forwarded unchanged to the model call.

    Returns:
        The 8-tuple ``(attn_metadata, hidden_states, spec_decode_metadata,
        positions, total_num_scheduled_tokens, sample_indices,
        aux_hidden_states, num_scheduled_tokens)``.

        NOTE(review): the annotated return type does not line up with these
        values - the first element is whatever
        ``self.attn_metadata_builder.build`` returns, and
        ``spec_decode_metadata`` / ``aux_hidden_states`` can be ``None``.
    """
    # Check input valid
    total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
    assert total_num_scheduled_tokens > 0
    num_reqs = self.input_batch.num_reqs
    assert num_reqs > 0
    if (self.use_aclgraph and
            total_num_scheduled_tokens <= self.aclgraph_batch_sizes[-1]):
        # Add padding to the batch size.
        num_input_tokens = self.vllm_config.pad_for_cudagraph(
            total_num_scheduled_tokens)
    else:
        # Eager mode.
        num_input_tokens = total_num_scheduled_tokens

    # The builder may reorder requests inside the batch; when it does, the
    # sampling metadata has to be refreshed to match the new order.
    modified_batch = self.attn_metadata_builder.reorder_batch(
        self.input_batch, scheduler_output)
    if modified_batch:
        self.input_batch.refresh_sampling_metadata()

    # OPTIMIZATION: Start copying the block table first.
    # This way, we can overlap the copy with the following CPU operations.
    self.input_batch.block_table.commit(num_reqs)

    # Get the number of scheduled tokens for each request.
    # TODO: The Python loop can be slow. Optimize.
    num_scheduled_tokens = np.empty(num_reqs, dtype=np.int32)
    # Scheduled tokens minus the request's scheduled draft tokens.
    num_valid_tokens = np.empty(num_reqs, dtype=np.int32)
    max_num_scheduled_tokens = 0
    for i, req_id in enumerate(self.input_batch.req_ids):
        num_tokens = scheduler_output.num_scheduled_tokens[req_id]
        num_scheduled_tokens[i] = num_tokens
        num_valid_tokens[i] = num_tokens - \
            len(scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
        max_num_scheduled_tokens = max(max_num_scheduled_tokens,
                                       num_tokens)

    # Hot-Swap lora model
    if self.lora_config:
        self.set_active_loras(self.input_batch, num_scheduled_tokens)

    # Prepare positions
    # req_indices: for each scheduled token, the batch index of its request.
    req_indices = np.repeat(self.arange_np[:num_reqs],
                            num_scheduled_tokens)
    cu_num_tokens = np.cumsum(num_scheduled_tokens)
    cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens,
                                num_scheduled_tokens)
    # Index of the last scheduled token of each request.
    sample_indices = cu_num_tokens - 1
    sample_indices = torch.from_numpy(sample_indices).to(self.device,
                                                         non_blocking=True)
    # arange: each token's offset within its own request.
    arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets

    # position = tokens already computed for the request + offset in request.
    positions_np = self.positions_np[:total_num_scheduled_tokens]
    np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
           arange,
           out=positions_np)

    # Calculate M-RoPE positions.
    # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
    if self.uses_mrope:
        self._calc_mrope_positions(scheduler_output)

    if self.uses_mrope:
        # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
        self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
            self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
            non_blocking=True)

    # Zero the padded tail, then copy the real positions to the device.
    # NOTE(review): positions were written through `self.positions_np` but
    # are copied from `self.positions_cpu`; presumably the former is a NumPy
    # view of the latter - confirm where the buffers are created.
    self.positions[total_num_scheduled_tokens:num_input_tokens].zero_()
    self.positions[:total_num_scheduled_tokens].copy_(
        self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True)
    positions = self.positions[:num_input_tokens]
    self.query_lens = torch.from_numpy(num_scheduled_tokens)

    # Sequence length after this step = computed tokens + scheduled tokens.
    # NOTE(review): written via `seq_lens_np`, read via `seq_lens_cpu`;
    # presumably these share memory - confirm.
    self.seq_lens_np[:num_reqs] = (
        self.input_batch.num_computed_tokens_cpu[:num_reqs] +
        num_scheduled_tokens)
    seq_lens = self.seq_lens_cpu[:num_reqs]

    # Flat index into the block table for every token:
    # request row * blocks-per-request + block index within the request.
    block_table_indices = (req_indices * self.max_num_blocks_per_req +
                           positions_np // self.block_size)

    block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor()
    block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
    block_offsets = positions_np % self.block_size
    # slot = block number * block size + offset inside the block.
    np.add(block_numbers * self.block_size,
           block_offsets,
           out=self.slot_mapping_np[:total_num_scheduled_tokens])

    ascend_config = get_ascend_config()
    # NOTE(review): this value is not read before `use_spec_decode` is
    # assigned again after the forward pass.
    use_spec_decode = len(
        scheduler_output.scheduled_spec_decode_tokens) > 0
    # Classify the batch. The first matching condition wins.
    # Every sequence length equals its scheduled length, i.e. no request has
    # previously computed tokens.
    if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
        attn_state = AscendAttentionState.PrefillNoCache
    # We assume it is the decode stage, where prefill occurs but only one token is not hit in cache.
    elif np.all(num_scheduled_tokens == 1):
        attn_state = AscendAttentionState.DecodeOnly
    # Speculative decoding.
    elif np.all(num_valid_tokens == 1):
        if self.use_eagle:
            attn_state = AscendAttentionState.ChunkedPrefill
        else:
            attn_state = AscendAttentionState.SpecDecoding
    # splitfuse
    elif not ascend_config.ascend_scheduler_config.enabled or self.chunked_prefill_enabled:
        attn_state = AscendAttentionState.ChunkedPrefill
    else:
        attn_state = AscendAttentionState.PrefillCacheHit

    attn_mask = self._make_attention_mask(seq_lens=seq_lens,
                                          query_lens=num_scheduled_tokens,
                                          position=positions,
                                          attn_state=attn_state)
    # Stored on the runner; not read again inside this method.
    self.attn_mask = attn_mask
    self.attn_state = attn_state  # type: ignore

    # Optional keyword arguments for the metadata builder, filled in below.
    extra_builder_kwargs = {}

    self.query_start_loc_np[0] = 0
    self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
    self.query_start_loc[:num_reqs + 1].copy_(
        self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
    self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
                                   non_blocking=True)

    # Pad the unused tail: seq_lens with 0 and query_start_loc with -1.
    # Needed for reshape_and_cache
    self.seq_lens[num_reqs:].fill_(0)
    self.query_start_loc[num_reqs + 1:].fill_(-1)

    query_start_loc = self.query_start_loc[:num_reqs + 1]
    # From here on `seq_lens` refers to the device tensor slice.
    seq_lens = self.seq_lens[:num_reqs]
    common_attn_metadata = CommonAttentionMetadata(
        query_start_loc=query_start_loc, seq_lens=seq_lens)
    with_prefill = attn_state not in [
        AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
    ]

    if self.dp_size > 1:
        # `with_prefill` is replaced by the value returned for the whole
        # data-parallel group.
        max_num_tokens, with_prefill = self._get_forward_metadata_across_dp(
            total_num_scheduled_tokens, with_prefill)
        extra_builder_kwargs['max_num_tokens_across_dp'] = max_num_tokens
        extra_builder_kwargs['with_prefill_across_dp'] = with_prefill

    # Add graph_pad_size here
    # `padded_batch_size` is only defined on this path; the later uses are
    # guarded by the same condition.
    if self.torchair_graph_enabled and not with_prefill:
        if self.dp_size > 1:
            padded_batch_size = self.select_torchair_padded_batch_size(
                max_num_tokens)
        else:
            padded_batch_size = self.select_torchair_padded_batch_size(
                total_num_scheduled_tokens)
        graph_pad_size = padded_batch_size - total_num_scheduled_tokens

        extra_builder_kwargs['graph_pad_size'] = graph_pad_size

    # The MLA builder additionally receives the common attention metadata.
    if self.vllm_config.model_config.use_mla:
        attn_metadata = self.attn_metadata_builder.build(  # type: ignore
            num_reqs=num_reqs,
            num_actual_tokens=total_num_scheduled_tokens,
            max_query_len=max_num_scheduled_tokens,
            common_attn_metadata=common_attn_metadata,
            common_prefix_len=None,
            **extra_builder_kwargs,
        )
    else:
        attn_metadata = self.attn_metadata_builder.build(  # type: ignore
            num_reqs=num_reqs,
            num_actual_tokens=total_num_scheduled_tokens,
            max_query_len=max_num_scheduled_tokens,
            common_prefix_len=None,
            **extra_builder_kwargs,
        )
    attn_metadata.num_input_tokens = num_input_tokens

    # Prepare input_ids
    # Flat index into the (request, position) token-id table.
    token_indices = (positions_np +
                     req_indices * self.input_batch.token_ids_cpu.shape[1])
    torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
                       0,
                       torch.from_numpy(token_indices),
                       out=self.input_ids_cpu[:total_num_scheduled_tokens])
    # Copy the tensors to the NPU.
    self.input_ids[:total_num_scheduled_tokens].copy_(
        self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)

    # _prepare_inputs may reorder the batch, so we must gather multi
    # modal outputs after that to ensure the correct order
    if self.is_multimodal_model:
        # Run the multimodal encoder if any.
        self._execute_mm_encoder(scheduler_output)
        mm_embeds = self._gather_mm_embeddings(scheduler_output)
    else:
        mm_embeds = []

    if self.is_multimodal_model:
        # NOTE(woosuk): To unify token ids and soft tokens (vision
        # embeddings), we always use embeddings (rather than token ids)
        # as input to the multimodal model, even when the input is text.
        input_ids = self.input_ids[:total_num_scheduled_tokens]
        if mm_embeds:
            inputs_embeds = self.model.get_input_embeddings(
                input_ids, mm_embeds)
        else:
            inputs_embeds = self.model.get_input_embeddings(input_ids)
        # TODO(woosuk): Avoid the copy. Optimize.
        self.inputs_embeds[:total_num_scheduled_tokens].copy_(
            inputs_embeds)
        inputs_embeds = self.inputs_embeds[:num_input_tokens]
        input_ids = None
    else:
        # For text-only models, we use token ids as input.
        # While it is possible to use embeddings as input just like the
        # multimodal models, it is not desirable for performance since
        # then the embedding layer is not included in the ACL graph.
        input_ids = self.input_ids[:num_input_tokens]
        inputs_embeds = None
    if self.uses_mrope:
        positions = self.mrope_positions[:, :num_input_tokens]

    # NOTE(review): on the torchair path without prefill, `input_ids` and
    # `positions` are re-sliced from the 1-D buffers regardless of the
    # multimodal / M-RoPE choices made above - confirm this is intended.
    if self.torchair_graph_enabled and not with_prefill:
        input_ids = self.input_ids[:padded_batch_size]
        positions = self.positions[:padded_batch_size]

    # Run forward pass
    with set_forward_context(attn_metadata,
                             self.vllm_config,
                             num_tokens=num_input_tokens):
        with ProfileExecuteDuration().capture_async("forward"):
            model_kwargs = {}
            # With torchair enabled, the KV caches and attention metadata
            # are passed to the model explicitly.
            if self.torchair_graph_enabled:
                model_kwargs["kv_caches"] = self.kv_caches
                model_kwargs["attn_metadata"] = attn_metadata
            if self.torchair_graph_enabled and not with_prefill:
                # Graph path: weights in FRACTAL_NZ format, model compiled
                # for the padded batch size.
                maybe_converting_weight_acl_format(self.model,
                                                   ACL_FORMAT_FRACTAL_NZ)

                compiled_model = self._get_torchair_lazy_compiled_model(
                    padded_batch_size)
                hidden_states = compiled_model(
                    input_ids=input_ids,
                    positions=positions,
                    intermediate_tensors=intermediate_tensors,
                    inputs_embeds=inputs_embeds,
                    **model_kwargs,
                )
            else:
                # Eager path: weights in FRACTAL_ND format.
                assert self.model is not None
                maybe_converting_weight_acl_format(self.model,
                                                   ACL_FORMAT_FRACTAL_ND)

                hidden_states = self.model(
                    input_ids=input_ids,
                    positions=positions,
                    intermediate_tensors=intermediate_tensors,
                    inputs_embeds=inputs_embeds,
                    **model_kwargs,
                )

    use_spec_decode = len(
        scheduler_output.scheduled_spec_decode_tokens) > 0
    if not use_spec_decode:
        # NOTE(woosuk): Due to chunked prefills, the batch may contain
        # partial requests. While we should not sample any token
        # from these partial requests, we do so for simplicity.
        # We will ignore the sampled tokens from the partial requests.
        # TODO: Support prompt logprobs.
        spec_decode_metadata = None
    else:
        # Get the number of draft tokens for each request.
        # Iterate over the dictionary rather than all requests since not all
        # requests have draft tokens.
        num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
        for req_id, draft_token_ids in (
                scheduler_output.scheduled_spec_decode_tokens.items()):
            req_idx = self.input_batch.req_id_to_index[req_id]
            num_draft_tokens[req_idx] = len(draft_token_ids)

        spec_decode_metadata = self._calc_spec_decode_metadata(
            num_draft_tokens, cu_num_tokens)
        # With draft tokens, sample at every position listed by the
        # spec-decode metadata instead of only the last token per request.
        sample_indices = spec_decode_metadata.logits_indices

    aux_hidden_states = None
    # In this mode the model output is a (hidden_states, aux_hidden_states)
    # pair.
    if self.use_aux_hidden_state_outputs:
        hidden_states, aux_hidden_states = hidden_states

    return (attn_metadata, hidden_states, spec_decode_metadata, positions,
            total_num_scheduled_tokens, sample_indices, aux_hidden_states,
            num_scheduled_tokens)
|
|
|
|
def _get_cumsum_and_arange(
|
|
self,
|
|
num_tokens: np.ndarray,
|
|
cumsum_dtype: Optional[np.dtype] = None,
|
|
) -> tuple[np.ndarray, np.ndarray]:
|
|
"""Get the cumulative sum and batched arange of the given array.
|
|
# E.g., [2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2])
|
|
# Equivalent to but faster than:
|
|
# np.concatenate([np.arange(n) for n in num_tokens])
|
|
"""
|
|
# Step 1. [2, 5, 3] -> [2, 7, 10]
|
|
cu_num_tokens = np.cumsum(num_tokens, dtype=cumsum_dtype)
|
|
total_num_tokens = cu_num_tokens[-1]
|
|
# Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7]
|
|
cumsums_offsets = np.repeat(cu_num_tokens - num_tokens, num_tokens)
|
|
# Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
|
|
arange = self.arange_np[:total_num_tokens] - cumsums_offsets
|
|
|
|
return cu_num_tokens, arange
|
|
|
|
def _calc_spec_decode_metadata(
    self,
    num_draft_tokens: np.ndarray,
    cu_num_scheduled_tokens: np.ndarray,
) -> SpecDecodeMetadata:
    """Build the index tensors used to verify draft tokens.

    Worked example:
        Inputs:
            cu_num_scheduled_tokens:  [  4, 104, 107, 207, 209]
            num_draft_tokens:         [  3,   0,   2,   0,   1]
        Outputs:
            cu_num_draft_tokens:      [  3,   3,   5,   5,   6]
            logits_indices:           [  0,   1,   2,   3, 103, 104, 105,
                                       106, 206, 207, 208]
            target_logits_indices:    [  0,   1,   2,   5,   6,   9]
            bonus_logits_indices:     [  3,   4,   7,   8,  10]

    Args:
        num_draft_tokens: Number of draft tokens per request.
        cu_num_scheduled_tokens: Cumulative scheduled-token counts per
            request.

    Returns:
        A ``SpecDecodeMetadata`` whose index tensors live on
        ``self.device``; ``num_draft_tokens`` is stored as a Python list.
    """
    # Every request samples one token per draft token plus one more:
    # [4, 1, 3, 1, 2]
    sampled_per_req = num_draft_tokens + 1

    # cu_sampled:     [4, 5, 8, 9, 11]
    # sampled_arange: [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
    cu_sampled, sampled_arange = self._get_cumsum_and_arange(
        sampled_per_req, cumsum_dtype=np.int32)

    # First sampled position of each request, repeated per sampled token:
    # [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]
    logits_idx_np = np.repeat(cu_num_scheduled_tokens - sampled_per_req,
                              sampled_per_req)
    # [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
    logits_idx_np += sampled_arange

    # The bonus logit is the last sampled position of each request.
    bonus_idx_np = cu_sampled - 1

    # cu_draft:     [3, 3, 5, 5, 6]
    # draft_arange: [0, 1, 2, 0, 1, 0]
    cu_draft_np, draft_arange = self._get_cumsum_and_arange(
        num_draft_tokens, cumsum_dtype=np.int32)

    # [0, 0, 0, 5, 5, 9]
    target_idx_np = np.repeat(cu_sampled - sampled_per_req, num_draft_tokens)
    # [0, 1, 2, 5, 6, 9]
    target_idx_np += draft_arange

    def _to_device(array: np.ndarray) -> torch.Tensor:
        # TODO: Optimize the CPU -> NPU copy.
        return torch.from_numpy(array).to(self.device, non_blocking=True)

    cu_num_draft_tokens = _to_device(cu_draft_np)
    logits_indices = _to_device(logits_idx_np)
    target_logits_indices = _to_device(target_idx_np)
    bonus_logits_indices = _to_device(bonus_idx_np)

    # Compute the draft token ids.
    # draft_token_indices:      [  1,   2,   3, 105, 106, 208]
    sampled_input_ids = self.input_ids[logits_indices]
    draft_token_ids = sampled_input_ids[target_logits_indices + 1]

    return SpecDecodeMetadata(
        draft_token_ids=draft_token_ids,
        num_draft_tokens=num_draft_tokens.tolist(),
        cu_num_draft_tokens=cu_num_draft_tokens,
        target_logits_indices=target_logits_indices,
        bonus_logits_indices=bonus_logits_indices,
        logits_indices=logits_indices,
    )
|
|
|
|
def apply_grammar_bitmask(
    self,
    scheduler_output: "SchedulerOutput",
    logits: torch.Tensor,
) -> torch.Tensor:
    """Apply the scheduler's structured-output bitmask to ``logits``.

    Args:
        scheduler_output: Provides ``grammar_bitmask`` (a NumPy array) and
            ``structured_output_request_ids`` (request id -> bitmask row).
        logits: Logits to mask.

    Returns:
        The masked logits, moved back to ``self.device`` and cast back to
        the dtype ``logits`` had on entry.
    """
    # Serialization of np.ndarray is much more efficient than a tensor,
    # so the bitmask arrives in that format.
    bitmask_np = scheduler_output.grammar_bitmask
    bitmask_rows = scheduler_output.structured_output_request_ids

    # The scheduler does not know how the runner orders requests in the
    # batch, so bitmask rows may not line up with batch rows. Record the
    # batch row of every structured-output request and note any mismatch.
    batch_row_by_req: dict[str, int] = {}
    needs_reorder = False
    for req_id in self.input_batch.req_ids:
        mask_row = bitmask_rows.get(req_id)
        if mask_row is None:
            # Not a structured output request.
            continue
        batch_row = self.input_batch.req_id_to_index[req_id]
        if batch_row != mask_row:
            needs_reorder = True
        batch_row_by_req[req_id] = batch_row

    if needs_reorder:
        # Rebuild the bitmask so that row i belongs to batch row i.
        reordered = np.zeros_like(bitmask_np)
        for req_id, batch_row in batch_row_by_req.items():
            reordered[batch_row] = bitmask_np[bitmask_rows[req_id]]
        bitmask_np = reordered

    bitmask = torch.from_numpy(bitmask_np)

    # TODO: compatibility with spec decode.
    # NOTE:
    # 1. XGrammar bitmask applying only supports CPU and GPU.
    # 2. The logits and bitmask should be on the same device.
    # 3. XGrammar logits on CPU only supports float32 dtype.
    original_dtype = logits.dtype
    cpu_logits = logits.to("cpu").float()
    xgr.apply_token_bitmask_inplace(
        cpu_logits,
        bitmask,
        indices=list(batch_row_by_req.values()),
    )
    return cpu_logits.to(self.device).to(original_dtype)
|
|
|
|
def _get_spec_token_ids(
|
|
self,
|
|
valid_sampled_token_ids: list[list[int]],
|
|
sampling_metadata: SamplingMetadata,
|
|
scheduler_output: "SchedulerOutput",
|
|
spec_decode_metadata: SpecDecodeMetadata,
|
|
positions: torch.Tensor,
|
|
num_scheduled_tokens: int,
|
|
hidden_states: torch.Tensor,
|
|
attn_metadata: SpecDecodeMetadata,
|
|
aux_hidden_states: torch.Tensor = None,
|
|
) -> Optional[list[list[int]]]:
|
|
if not self.use_spec_decode:
|
|
# Speculative decoding is not enabled.
|
|
spec_token_ids = None
|
|
elif self.speculative_config.method == "ngram":
|
|
assert isinstance(self.drafter, NgramProposer)
|
|
spec_token_ids = self._generate_draft_token_ids(
|
|
valid_sampled_token_ids, sampling_metadata)
|
|
elif self.speculative_config.method == "eagle":
|
|
raise NotImplementedError("Eagle Is Not Supported Yet.")
|
|
elif self.speculative_config.method == "eagle3":
|
|
assert isinstance(self.drafter, EagleProposer)
|
|
if self.speculative_config.use_eagle():
|
|
next_token_ids: list[int] = []
|
|
for i, token_ids in enumerate(valid_sampled_token_ids):
|
|
if token_ids:
|
|
# Common case.
|
|
next_token_id = token_ids[-1]
|
|
else:
|
|
# Partial prefill (rare case).
|
|
# Get the next token id from the request state.
|
|
req_id = self.input_batch.req_ids[i]
|
|
req_state = self.requests[req_id]
|
|
seq_len = (
|
|
req_state.num_computed_tokens +
|
|
scheduler_output.num_scheduled_tokens[req_id])
|
|
|
|
next_token_id = req_state.get_token_id(seq_len)
|
|
next_token_ids.append(next_token_id)
|
|
next_token_ids = torch.tensor(next_token_ids,
|
|
dtype=torch.int32,
|
|
device=self.device)
|
|
eagle_attn_metadata = attn_metadata[
|
|
self.drafter.attn_layer_name]
|
|
num_input_tokens = scheduler_output.total_num_scheduled_tokens
|
|
if spec_decode_metadata is None:
|
|
# input_ids can be None for multimodal models.
|
|
target_token_ids = self.input_ids[:num_scheduled_tokens]
|
|
target_positions = positions[:num_scheduled_tokens]
|
|
if self.use_aux_hidden_state_outputs:
|
|
target_hidden_states = torch.cat([
|
|
h[:num_scheduled_tokens] for h in aux_hidden_states
|
|
],
|
|
dim=-1)
|
|
else:
|
|
target_hidden_states = hidden_states[:
|
|
num_scheduled_tokens]
|
|
target_slot_mapping = eagle_attn_metadata.slot_mapping
|
|
cu_num_tokens = eagle_attn_metadata.query_start_loc
|
|
else:
|
|
num_draft_tokens = spec_decode_metadata.num_draft_tokens
|
|
num_rejected_tokens = [
|
|
n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
|
|
for i, n in enumerate(num_draft_tokens)
|
|
]
|
|
num_rejected_tokens = torch.tensor(
|
|
num_rejected_tokens,
|
|
dtype=torch.int32,
|
|
device=self.device,
|
|
)
|
|
num_tokens = num_scheduled_tokens - sum(
|
|
num_rejected_tokens)
|
|
cu_num_tokens, token_indices = self.drafter.prepare_inputs(
|
|
eagle_attn_metadata.query_start_loc,
|
|
num_rejected_tokens, num_tokens)
|
|
target_token_ids = self.input_ids[token_indices]
|
|
target_positions = positions[token_indices]
|
|
if self.use_aux_hidden_state_outputs:
|
|
target_hidden_states = torch.cat(
|
|
[h[token_indices] for h in aux_hidden_states],
|
|
dim=-1)
|
|
else:
|
|
target_hidden_states = hidden_states[token_indices]
|
|
target_slot_mapping = eagle_attn_metadata.slot_mapping[
|
|
token_indices]
|
|
|
|
positions = self.positions[:num_input_tokens]
|
|
draft_token_ids = self.drafter.propose(
|
|
target_token_ids=target_token_ids,
|
|
target_positions=target_positions,
|
|
target_hidden_states=target_hidden_states,
|
|
target_slot_mapping=target_slot_mapping,
|
|
next_token_ids=next_token_ids,
|
|
cu_num_tokens=cu_num_tokens,
|
|
block_table=eagle_attn_metadata.block_tables,
|
|
sampling_metadata=sampling_metadata,
|
|
)
|
|
spec_token_ids = draft_token_ids.tolist()
|
|
elif self.speculative_config.method == 'deepseek_mtp':
|
|
assert isinstance(self.drafter, MtpProposer)
|
|
spec_token_ids = self._generate_mtp_token_ids(
|
|
valid_sampled_token_ids, sampling_metadata, scheduler_output,
|
|
spec_decode_metadata, positions, num_scheduled_tokens,
|
|
hidden_states, attn_metadata)
|
|
return spec_token_ids
|
|
|
|
def _pool(
    self,
    hidden_states: torch.Tensor,
    num_scheduled_tokens: int,
    num_scheduled_tokens_np: np.ndarray,
) -> ModelRunnerOutput:
    """Run the model's pooler over this step's hidden states.

    Args:
        hidden_states: Hidden states from the forward pass; only the first
            ``num_scheduled_tokens`` rows are used.
        num_scheduled_tokens: Total number of tokens scheduled this step.
        num_scheduled_tokens_np: Per-request scheduled-token counts, used to
            split ``hidden_states`` into one chunk per request.

    Returns:
        A ``ModelRunnerOutput`` with no sampled tokens. ``pooler_output``
        holds a CPU tensor for each request whose sequence length equals its
        prompt length, and ``None`` for every other request.
    """
    batch = self.input_batch
    assert batch.num_reqs == len(batch.pooling_params), \
        "Either all or none of the requests in" \
        " a batch must be pooling request"

    # One chunk of hidden states per request.
    per_request_states = list(
        torch.split(hidden_states[:num_scheduled_tokens],
                    num_scheduled_tokens_np.tolist()))

    pooling_metadata = self.input_batch.pooling_metadata

    raw_pooler_output = self.model.pooler(
        hidden_states=per_request_states,
        pooling_metadata=pooling_metadata)

    seq_lens = self.seq_lens[:self.input_batch.num_reqs]
    # Emit a result only for requests whose sequence length has reached the
    # prompt length.
    pooler_output: list[Optional[torch.Tensor]] = [
        raw_output.data.cpu() if seq_len == prompt_len else None
        for raw_output, seq_len, prompt_len in zip(
            raw_pooler_output, seq_lens, pooling_metadata.prompt_lens)
    ]

    return ModelRunnerOutput(
        req_ids=self.input_batch.req_ids,
        req_id_to_index=self.input_batch.req_id_to_index,
        sampled_token_ids=[],
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=pooler_output,
    )
|
|
|
|
@torch.inference_mode()
def execute_model(
    self,
    scheduler_output: "SchedulerOutput",
    intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[ModelRunnerOutput, torch.Tensor]:
    """Run one scheduling step: forward pass, sampling and post-processing.

    Args:
        scheduler_output: The scheduler's decision for this step.
        intermediate_tensors: Forwarded unchanged to ``_process_reqs``.

    Returns:
        ``EMPTY_MODEL_RUNNER_OUTPUT`` when no tokens are scheduled, the
        result of ``_pool`` when the batch carries pooling params, and
        otherwise a ``ModelRunnerOutput`` with the sampled tokens, the
        proposed speculative tokens and the requested logprobs.
    """
    with ProfileExecuteDuration().capture_async(
            "prepare input and forward"):
        self._update_states(scheduler_output)
        if not scheduler_output.total_num_scheduled_tokens:
            # Return empty ModelRunnerOutput if there's no work to do.
            return EMPTY_MODEL_RUNNER_OUTPUT
        (attn_metadata, hidden_states, spec_decode_metadata, positions,
         num_scheduled_tokens, sample_indices, aux_hidden_states,
         num_scheduled_tokens_np) = (self._process_reqs(
             scheduler_output, intermediate_tensors))

    with ProfileExecuteDuration().capture_async("post process"):

        if self.input_batch.pooling_params:
            return self._pool(hidden_states, num_scheduled_tokens,
                              num_scheduled_tokens_np)
        logits = self.model.compute_logits(hidden_states[sample_indices],
                                           None)
        if self.use_eagle:
            attn_metadata = self.get_eagle_atten_dict(scheduler_output)
        # Apply structured output bitmasks if present
        if scheduler_output.grammar_bitmask is not None:
            logits = self.apply_grammar_bitmask(scheduler_output, logits)

        # Sample the next token and get logprobs if needed.
        sampling_metadata = self.input_batch.sampling_metadata
        if spec_decode_metadata is None:
            sampler_output = self.sampler(
                logits=logits,
                sampling_metadata=sampling_metadata,
            )
        else:
            # When indexing with a tensor (bonus_logits_indices), PyTorch
            # creates a new tensor with separate storage from the original
            # logits tensor. This means any in-place operations on bonus_logits
            # won't affect the original logits tensor.
            bonus_logits = logits[
                spec_decode_metadata.bonus_logits_indices]
            sampler_output = self.sampler(
                logits=bonus_logits,
                sampling_metadata=sampling_metadata,
            )
            bonus_token_ids = sampler_output.sampled_token_ids

            # Just like `bonus_logits`, `target_logits` is a new tensor with
            # separate storage from the original `logits` tensor. Therefore,
            # it is safe to update `target_logits` in place.
            target_logits = logits[
                spec_decode_metadata.target_logits_indices]
            output_token_ids = self.rejection_sampler(
                spec_decode_metadata,
                None,  # draft_probs
                target_logits,
                bonus_token_ids,
                sampling_metadata,
            )
            sampler_output.sampled_token_ids = output_token_ids

        # TODO(woosuk): The following loop can be slow since it iterates over
        # the requests one by one. Optimize.
        discard_sampled_tokens_req_indices: list[int] = []
        for i, req_id in enumerate(self.input_batch.req_ids):
            req_state = self.requests[req_id]
            seq_len = (req_state.num_computed_tokens +
                       scheduler_output.num_scheduled_tokens[req_id])
            if seq_len < req_state.num_tokens:
                # Ignore the sampled token.
                # Rewind the generator state as if the token was not sampled.
                generator = self.input_batch.generators.get(i)
                if generator is not None:
                    generator.set_offset(generator.get_offset() - 4)
                discard_sampled_tokens_req_indices.append(i)

        # NOTE: NPU -> CPU Sync happens here.
        # Move as many CPU operations as possible before this sync point.
        logprobs_tensors = sampler_output.logprobs_tensors
        logprobs_lists = logprobs_tensors.tolists() \
            if logprobs_tensors is not None else None

        # Compute prompt logprobs if needed.
        prompt_logprobs_dict = self._get_prompt_logprobs_dict(
            hidden_states[:num_scheduled_tokens],
            scheduler_output,
        )

        # Get the valid generated tokens.
        sampled_token_ids = sampler_output.sampled_token_ids
        max_gen_len = sampled_token_ids.shape[-1]
        if max_gen_len == 1:
            # No spec decode tokens.
            valid_sampled_token_ids = sampled_token_ids.tolist()
        else:
            # Includes spec decode tokens.
            valid_sampled_token_ids = self.rejection_sampler.parse_output(
                sampled_token_ids,
                self.input_batch.vocab_size,
            )

        for i in discard_sampled_tokens_req_indices:
            valid_sampled_token_ids[i].clear()

        # Evaluate the version switch once; it selects both the token
        # caching below and the shape of the returned output.
        is_vllm_0_9_1 = vllm_version_is("0.9.1")
        if not is_vllm_0_9_1:
            # Cache the sampled tokens in the model runner, so that the
            # scheduler doesn't need to send them back.
            # NOTE(woosuk): As an exception, when using PP, the scheduler sends
            # the sampled tokens back, because there's no direct communication
            # between the first-stage worker and the last-stage worker.
            for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
                if not sampled_ids:
                    continue

                start_idx = self.input_batch.num_tokens_no_spec[req_idx]
                end_idx = start_idx + len(sampled_ids)
                assert end_idx <= self.model_config.max_model_len, (
                    "Sampled token IDs exceed the max model length. "
                    f"Total number of tokens: {end_idx} > max_model_len: "
                    f"{self.model_config.max_model_len}")

                self.input_batch.token_ids_cpu[
                    req_idx, start_idx:end_idx] = sampled_ids
                self.input_batch.num_tokens_no_spec[req_idx] = end_idx
                self.input_batch.num_tokens[req_idx] = end_idx
                req_id = self.input_batch.req_ids[req_idx]
                req_state = self.requests[req_id]
                req_state.output_token_ids.extend(sampled_ids)

        spec_token_ids = self._get_spec_token_ids(
            valid_sampled_token_ids,
            sampling_metadata,
            scheduler_output,
            spec_decode_metadata,
            positions,
            num_scheduled_tokens,
            hidden_states,
            attn_metadata,
            aux_hidden_states,
        )

        output_kwargs = dict(
            req_ids=self.input_batch.req_ids,
            req_id_to_index=self.input_batch.req_id_to_index,
            sampled_token_ids=valid_sampled_token_ids,
            spec_token_ids=spec_token_ids,
            logprobs=logprobs_lists,
            prompt_logprobs_dict=prompt_logprobs_dict,
        )
        if not is_vllm_0_9_1:
            # `pooler_output` is only passed on versions other than 0.9.1.
            output_kwargs["pooler_output"] = []
        model_runner_output = ModelRunnerOutput(**output_kwargs)

    durations = ProfileExecuteDuration().pop_captured_sync()
    if durations:
        dr_str = [
            f"[{tag}]:{duration:.2f}ms"
            for tag, duration in durations.items()
        ]
        captured_name = "Decode" if self.attn_state == AscendAttentionState.DecodeOnly else "Prefill"
        logger.info("Profile execute duration [%s]:%s", captured_name,
                    " ".join(dr_str))

    return model_runner_output
|
|
|
|
def _profile_multimodal(self) -> None:
    """Run the multimodal encoder on a dummy batch and cache its outputs.

    Returns immediately when the model is not multimodal or when either
    the encoder input-token budget or the encoder cache size is not
    positive. Otherwise the modality with the largest per-item token count
    is chosen, a batch of identical dummy items is encoded, and the
    results are stored under ``self.encoder_cache["tmp"]``.

    Raises:
        RuntimeError: If the dummy multimodal data is not a
            ``MultiModalKwargs`` instance.
    """
    # TODO: handle encoder-decoder models once we support them.
    # NOTE: The model is currently profiled with a single non-text modality
    # at its maximum possible input size, even when it supports several.
    can_profile = (self.is_multimodal_model
                   and self.max_num_encoder_input_tokens > 0
                   and self.encoder_cache_size > 0)
    if not can_profile:
        return

    tokens_per_modality = (
        MULTIMODAL_REGISTRY.get_max_tokens_per_item_by_nonzero_modality(
            self.model_config))
    dummy_data_modality, max_tokens_per_mm_item = max(
        tokens_per_modality.items(), key=lambda entry: entry[1])

    # How many items of this modality the encoder budget can hold.
    encoder_budget = min(self.max_num_encoder_input_tokens,
                         self.encoder_cache_size)
    items_within_encoder_budget = cdiv(encoder_budget,
                                       max_tokens_per_mm_item)

    # How many items of this modality the decoder budget can hold.
    # NOTE: max_num_batched_tokens is deliberately not considered because
    # the multimodal embeddings can be generated in advance and chunked
    # prefilled.
    per_request_limit = self.mm_registry.get_mm_limits_per_prompt(
        self.model_config)[dummy_data_modality]
    items_within_decoder_budget = self.max_num_reqs * per_request_limit

    max_num_mm_items = min(items_within_encoder_budget,
                           items_within_decoder_budget)

    logger.info(
        "Encoder cache will be initialized with a budget of %s tokens,"
        " and profiled with %s %s items of the maximum feature size.",
        encoder_budget, max_num_mm_items, dummy_data_modality)

    # Create dummy batch of multimodal inputs.
    dummy_mm_data = self.input_registry.dummy_data_for_profiling(
        model_config=self.model_config,
        seq_len=self.max_num_tokens,
        mm_registry=self.mm_registry,
    ).multi_modal_data

    if not isinstance(dummy_mm_data, MultiModalKwargs):
        # TODO: Delete this check once input mapper is fully removed.
        raise RuntimeError("Legacy input mapper is not supported in V1")

    # The dummy data may hold several items (e.g. multiple images) for one
    # request; only the first item is taken and replicated
    # max_num_mm_items times, since in V1 items are scheduled separately.
    first_item = dummy_mm_data.get_item(modality=dummy_data_modality,
                                        item_index=0)
    single_item_kwargs = MultiModalKwargs.from_items([first_item])
    batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
        MultiModalKwargs.batch([single_item_kwargs] * max_num_mm_items),
        device=self.device)

    # Run multimodal encoder.
    dummy_encoder_outputs = self.model.get_multimodal_embeddings(
        **batched_dummy_mm_inputs)
    assert len(dummy_encoder_outputs) == max_num_mm_items, (
        "Expected dimension 0 of encoder outputs to match the number "
        f"of multimodal data items: {max_num_mm_items}, got "
        f"{len(dummy_encoder_outputs)=} instead. This is most likely "
        "due to the 'get_multimodal_embeddings' method of the model "
        "not implemented correctly.")

    # Cache the dummy encoder outputs, keyed by their position in the batch.
    self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
|
|
|
|
@torch.inference_mode()
def _dummy_run(
    self,
    num_tokens: int,
    is_compile: bool = False,
    with_prefill: bool = True,
    skip_attn: bool = True,
) -> torch.Tensor:
    """Execute a forward pass on dummy inputs of ``num_tokens`` tokens.

    Args:
        num_tokens: Number of dummy tokens; must not exceed
            ``scheduler_config.max_num_batched_tokens``.
        is_compile: When the torchair path is taken, mark the inputs, the
            decode metadata tensors and the KV caches as static for dynamo.
        with_prefill: When ``False`` and torchair graph mode is enabled, the
            lazily compiled torchair model is used instead of the eager one.
        skip_attn: When ``False``, attention metadata is built up front via
            ``attn_metadata_builder.build``.

    Returns:
        The hidden states produced by the model. On the eager path with
        ``use_aux_hidden_state_outputs`` set, the model returns a pair and
        only its first element is returned.
    """
    # Set num_scheduled_tokens based on num_tokens and max_num_seqs
    # for dummy run with LoRA so that the num_reqs collectively
    # has num_tokens in total.
    assert num_tokens <= self.scheduler_config.max_num_batched_tokens
    num_reqs = min(self.max_num_reqs, num_tokens)
    min_tokens_per_req = num_tokens // num_reqs
    num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
    # The last request absorbs the remainder so the total is exact.
    num_scheduled_tokens_list[-1] += num_tokens % num_reqs
    assert sum(num_scheduled_tokens_list) == num_tokens
    assert len(num_scheduled_tokens_list) == num_reqs
    num_scheduled_tokens = np.array(num_scheduled_tokens_list,
                                    dtype=np.int32)
    if skip_attn:
        attn_metadata = None
    else:
        attn_metadata = self.attn_metadata_builder.build(
            num_reqs=num_tokens,
            num_actual_tokens=num_tokens,
            max_query_len=num_tokens,
            common_prefix_len=0,
        )

    with self.maybe_dummy_run_with_lora(self.lora_config,
                                        num_scheduled_tokens):
        model = self.model
        if self.is_multimodal_model:
            input_ids = None
            inputs_embeds = self.inputs_embeds[:num_tokens]
        else:
            input_ids = self.input_ids[:num_tokens]
            inputs_embeds = None

        if self.uses_mrope:
            positions = self.mrope_positions[:, :num_tokens]
        else:
            positions = self.positions[:num_tokens]

        if get_pp_group().is_first_rank:
            intermediate_tensors = None
        else:
            # Allocate the intermediate-tensor buffers on first use.
            if self.intermediate_tensors is None:
                self.intermediate_tensors = (
                    self.model.make_empty_intermediate_tensors(
                        batch_size=num_tokens,
                        dtype=self.dtype,
                        device=self.device))
            intermediate_tensors = IntermediateTensors({
                k: v[:num_tokens]
                for k, v in self.intermediate_tensors.items()
            })

        with set_forward_context(None,
                                 self.vllm_config,
                                 num_tokens=num_tokens):
            if self.torchair_graph_enabled and not with_prefill:
                attn_metadata = self.attn_metadata_builder.build_dummy(
                    num_reqs=num_tokens, num_actual_tokens=1)
                # Only mark static while compiling
                if is_compile:
                    torch._dynamo.mark_static(input_ids)
                    torch._dynamo.mark_static(positions)
                    torch._dynamo.mark_static(
                        attn_metadata.decode.block_table)
                    torch._dynamo.mark_static(
                        attn_metadata.decode.input_positions)
                    torch._dynamo.mark_static(attn_metadata.slot_mapping)
                    for kv in self.kv_caches:
                        assert isinstance(
                            kv, tuple), "kv_cache must be a tuple"
                        torch._dynamo.mark_static(kv[0])
                        torch._dynamo.mark_static(kv[1])

                # Weights are requested in NZ format before the compiled
                # graph runs.
                maybe_converting_weight_acl_format(self.model,
                                                   ACL_FORMAT_FRACTAL_NZ)

                compiled_model = self._get_torchair_lazy_compiled_model(
                    num_tokens)
                hidden_states = compiled_model(
                    input_ids=input_ids,
                    positions=positions,
                    intermediate_tensors=intermediate_tensors,
                    inputs_embeds=None,
                    kv_caches=self.kv_caches,
                    attn_metadata=attn_metadata,
                )
            else:
                # Weights are requested in ND format on the eager path.
                maybe_converting_weight_acl_format(self.model,
                                                   ACL_FORMAT_FRACTAL_ND)

                hidden_states = model(
                    input_ids=input_ids,
                    positions=positions,
                    intermediate_tensors=intermediate_tensors,
                    inputs_embeds=inputs_embeds)
                if self.use_aux_hidden_state_outputs:
                    # The model returned a pair; the second element is
                    # not needed for a dummy run.
                    hidden_states, _ = hidden_states
                if self.use_spec_decode and \
                    self.speculative_config.method in ('eagle', 'eagle3'):
                    assert isinstance(self.drafter, EagleProposer)
                    self.drafter.dummy_run(num_tokens)
        return hidden_states
|
|
|
|
def profile_run(self) -> None:
    """Run a maximum-size dummy forward pass, then release its outputs.

    Runs ``_dummy_run`` with ``self.max_num_tokens`` tokens. On the last
    pipeline-parallel rank it additionally runs either the dummy pooler or
    ``compute_logits`` on the last token of each dummy request. Afterwards
    the device is synchronized, the outputs are deleted, the encoder cache
    is cleared and garbage collection is triggered.
    """
    # FIXME Profile with multimodal encoder & encoder cache.
    # The current _profile_multimodal() uses the PyTorch SDPA backend, which
    # does not support window/full attention to reduce Memcpy operations and
    # therefore causes an Out Of Memory problem, so it is not called here:
    # self._profile_multimodal()

    # For profiling, use the maximum number of requests, which collectively
    # hold the maximum number of tokens.
    tokens_per_req, leftover = divmod(self.max_num_tokens,
                                      self.max_num_reqs)
    num_scheduled_tokens_list = [tokens_per_req] * self.max_num_reqs
    num_scheduled_tokens_list[-1] += leftover
    assert sum(num_scheduled_tokens_list) == self.max_num_tokens
    assert len(num_scheduled_tokens_list) == self.max_num_reqs

    # Index of the last token of every dummy request.
    logit_indices = np.cumsum(
        np.array(num_scheduled_tokens_list, dtype=np.int32)) - 1

    # assert self.lora_manager is not None, "LoRA is not enabled"
    # TODO: call maybe_profile_with_lora()

    # Trigger compilation for general shape.
    hidden_states = self._dummy_run(self.max_num_tokens)

    if not get_pp_group().is_last_rank:
        output = None
    elif self.is_pooling_model:
        output = self._dummy_pooler_run(hidden_states)
    else:
        # TODO: need to run a dummy sampler for generate task
        hidden_states = hidden_states[logit_indices]
        output = self.model.compute_logits(hidden_states, None)

    NPUPlatform.synchronize()
    del hidden_states, output
    self.encoder_cache.clear()
    gc.collect()
|
|
|
|
@torch.inference_mode()
def _dummy_pooler_run(
    self,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    """Run the model's pooler on dummy requests built from ``hidden_states``.

    The rows of ``hidden_states`` are divided among
    ``min(num_tokens, scheduler_config.max_num_seqs)`` dummy requests; the
    last request receives the remainder.

    Args:
        hidden_states: Hidden states whose first dimension is the number of
            tokens to distribute.

    Returns:
        Whatever ``self.model.pooler`` returns for the dummy batch.

    Raises:
        RuntimeError: Re-raised from the pooler. When the message contains
            'out of memory' it is wrapped in a more descriptive error.
    """
    num_tokens = hidden_states.shape[0]
    num_reqs = min(num_tokens, self.scheduler_config.max_num_seqs)
    tokens_per_req, leftover = divmod(num_tokens, num_reqs)

    chunk_sizes = [tokens_per_req] * num_reqs
    chunk_sizes[-1] += leftover
    assert sum(chunk_sizes) == num_tokens
    assert len(chunk_sizes) == num_reqs

    hidden_states_list = list(torch.split(hidden_states, chunk_sizes))

    prompt_lens = torch.tensor(
        [chunk.shape[0] for chunk in hidden_states_list],
        device=self.device)
    prompt_token_ids = torch.zeros((num_reqs, tokens_per_req),
                                   dtype=torch.int32,
                                   device=self.device)
    dummy_metadata = PoolingMetadata(
        prompt_lens=prompt_lens,
        prompt_token_ids=prompt_token_ids,
        pooling_params=[PoolingParams()] * num_reqs)

    try:
        return self.model.pooler(hidden_states=hidden_states_list,
                                 pooling_metadata=dummy_metadata)
    except RuntimeError as e:
        if 'out of memory' not in str(e):
            raise e
        raise RuntimeError(
            "NPU out of memory occurred when warming up pooler with "
            f"{num_reqs} dummy requests. Please try lowering "
            "`max_num_seqs` or `gpu_memory_utilization` when "
            "initializing the engine.") from e
|
|
|
|
def load_model(self) -> None:
    """Load the model, plus the drafter and LoRA wrapper when configured.

    Everything runs under a ``DeviceMemoryProfiler`` whose consumed memory
    is logged at the end. When ``is_310p()`` is true, the weights of the
    merged-column, QKV and row parallel linear layers are cast to the
    FRACTAL_NZ format.
    """
    logger.info("Starting to load model %s...", self.model_config.model)

    with DeviceMemoryProfiler() as mem_profiler:  # noqa: SIM117
        self.model = get_model(vllm_config=self.vllm_config)

        if is_310p():
            from vllm.model_executor.layers.linear import (
                MergedColumnParallelLinear, QKVParallelLinear,
                RowParallelLinear)
            nz_layer_types = (MergedColumnParallelLinear,
                              QKVParallelLinear, RowParallelLinear)
            linear_layers = (mod for mod in self.model.modules()
                             if isinstance(mod, nz_layer_types))
            for layer in linear_layers:
                layer.weight.data = torch_npu.npu_format_cast(
                    layer.weight.data, ACL_FORMAT_FRACTAL_NZ)

        try:
            # For version compatibility, remove this after we abort vllm v0.9.1 support
            from vllm.model_executor.models.interfaces import \
                has_step_pooler  # type: ignore
            if has_step_pooler(self.model):
                self.input_batch.logits_processing_needs_token_ids = True
        except ImportError:
            # `has_step_pooler` is unavailable; nothing to configure.
            pass

        if self.drafter:
            logger.info("Loading drafter model...")
            if not self.use_aux_hidden_state_outputs:
                self.drafter.load_model()
            else:
                # The drafter is given the target model, which is then told
                # which layers to expose as auxiliary hidden states.
                self.drafter.load_model(self.model)
                self.model.set_aux_hidden_state_layers(
                    self.model.get_eagle3_aux_hidden_state_layers())

        if self.lora_config:
            self.model = self.load_lora_model(self.model,
                                              self.model_config,
                                              self.scheduler_config,
                                              self.lora_config,
                                              self.device)

    logger.info("Loading model weights took %.4f GB",
                mem_profiler.consumed_memory / float(2**30))
|
|
|
|
def _get_torchair_lazy_compiled_model(self, batch_size: int):
    """Return the torchair-compiled model for ``batch_size``, compiling lazily.

    With ``use_cached_npu_graph`` set, one compiled callable is kept per
    batch size in ``self.torchair_compiled_models``; otherwise a single
    dynamically shaped ``torch.compile`` result is kept in
    ``self.torchair_compiled_model``.

    Args:
        batch_size: Graph batch size; must lie in ``[0, max_num_reqs]``.

    Returns:
        The compiled callable.

    Raises:
        ValueError: If ``batch_size`` is outside the allowed range.
    """
    if not 0 <= batch_size <= self.max_num_reqs:
        raise ValueError(
            f"Bad graph batch size:{batch_size}! max_num_reqs:{self.max_num_reqs}"
        )

    if self.use_cached_npu_graph:
        existing = self.torchair_compiled_models.get(batch_size)
    else:
        existing = self.torchair_compiled_model
    if existing:
        return existing

    import torchair  # type: ignore
    from torchair import patch_for_hcom  # type: ignore

    patch_for_hcom()

    running_on_310p = is_310p()
    if running_on_310p:
        # On the 300I Duo platform broadcast has to be patched, but
        # patch_for_hcom in torchair overwrites that patch, so it is
        # re-applied here.
        from vllm_ascend.patch.platform.patch_common.patch_distributed import \
            communication_adaptation_310p
        communication_adaptation_310p()

    config = torchair.CompilerConfig()
    experimental = config.experimental_config
    experimental.frozen_parameter = True
    # Enabling tiling_schedule_optimize on 300I Duo has some bugs, so it is
    # disabled on that platform for now.
    experimental.tiling_schedule_optimize = not running_on_310p
    experimental.enable_view_optimize = \
        get_ascend_config().torchair_graph_config.enable_view_optimize
    torch.npu.set_compile_mode(jit_compile=False)

    if not self.use_cached_npu_graph:
        self.torchair_compiled_model = torch.compile(
            self.model,
            dynamic=True,
            fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
            backend=torchair.get_npu_backend(compiler_config=config))
        return self.torchair_compiled_model

    # Generate a new forward proxy code object to prevent the invalidation
    # of compilation cache caused by dynamo retracing.
    forward_proxy_name = f"{self.model.__class__.__name__}_forward_with_batch_size_{batch_size}"
    forward_fn = self.model.forward
    # Same code object as `forward`, marked with the new proxy name.
    proxy_code = forward_fn.__code__.replace(co_name=forward_proxy_name, )
    proxy_func = types.FunctionType(proxy_code,
                                    forward_fn.__globals__,
                                    name=forward_proxy_name,
                                    argdefs=forward_fn.__defaults__)

    # Bind the proxy to the model instance and register it under its name.
    self.model.__dict__[forward_proxy_name] = proxy_func.__get__(
        self.model, nn.Module)
    compiled = torchair.inference.cache_compile(
        self.model.__dict__[forward_proxy_name],
        dynamic=True,
        fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
        config=config,
        ge_cache=False)
    self.torchair_compiled_models[batch_size] = compiled
    return self.torchair_compiled_models[batch_size]
|
|
|
|
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
    """
    Initialize KV cache based on `kv_cache_config`.

    Besides allocating the per-layer cache tensors and binding them with
    `bind_kv_cache`, this also replaces `self.input_batch` with a freshly
    constructed `InputBatch`.

    Args:
        kv_cache_config: Configuration for the KV cache, including the KV
        cache size of each layer

    Raises:
        ValueError: If a KV cache group uses a spec other than
        `FullAttentionSpec`.
    """
    self.kv_cache_config = kv_cache_config
    import torch_npu

    # NZ format is used on 310P, ND everywhere else.
    acl_format = (ACL_FORMAT_FRACTAL_NZ
                  if is_310p() else ACL_FORMAT_FRACTAL_ND)
    kv_caches: Dict[str, torch.Tensor] = {}

    self.input_batch = InputBatch(
        max_num_reqs=self.max_num_reqs,
        max_model_len=self.model_config.max_model_len,
        max_num_batched_tokens=self.max_num_tokens,
        device=self.device,
        pin_memory=True,
        vocab_size=self.model_config.get_vocab_size(),
        block_sizes=[self.block_size],
        is_spec_decode=bool(self.vllm_config.speculative_config),
    )

    # Map each layer name to the byte size of its (unshared) cache tensor.
    kv_cache_sizes = {}
    for cache_tensor in kv_cache_config.kv_cache_tensors:
        assert len(cache_tensor.shared_by) == 1, (
            "KV cache tensor shared by multiple layers is not supported in "
            "NPU.")
        kv_cache_sizes[cache_tensor.shared_by[0]] = cache_tensor.size

    for group in kv_cache_config.kv_cache_groups:
        spec = group.kv_cache_spec
        for layer_name in group.layer_names:
            tensor_size = kv_cache_sizes[layer_name]
            assert tensor_size % spec.page_size_bytes == 0
            num_blocks = tensor_size // spec.page_size_bytes

            # `num_blocks` is the number of blocks the model runner can use.
            # `kv_cache_config.num_blocks` is the number of blocks that
            # KVCacheManager may allocate.
            # Since different GPUs may have different number of layers and
            # different memory capacities, `num_blocks` can be different on
            # different GPUs, and `kv_cache_config.num_blocks` is set to
            # the min of all `num_blocks`. Verify it here.
            assert num_blocks >= kv_cache_config.num_blocks

            # TODO: remove this after the OOM issue is located and fixed,
            # otherwise, some model may encounter OOM issue
            if not isinstance(spec, FullAttentionSpec):
                # TODO: add new branches when introducing more types of
                # KV cache specs.
                raise ValueError("Unknown KV cache spec type.")

            # The int8 KV cache uses the BSH shape helper.
            wants_int8 = self.vllm_config.additional_config.get(
                "kv_cache_dtype", None) == 'int8'
            if wants_int8:
                kv_cache_shape = self.attn_backend.get_bsh_kv_cache_shape(
                    num_blocks, spec.block_size, spec.num_kv_heads,
                    spec.head_size)
            else:
                kv_cache_shape = self.attn_backend.get_kv_cache_shape(
                    num_blocks, spec.block_size, spec.num_kv_heads,
                    spec.head_size)

            if not self.torchair_graph_enabled:
                # Single tensor per layer, cast to the platform format.
                layer_cache = torch.zeros(kv_cache_shape,
                                          dtype=self.kv_cache_dtype,
                                          device=self.device)
                kv_caches[layer_name] = torch_npu.npu_format_cast(
                    layer_cache, acl_format)
            elif len(kv_cache_shape) == 3:
                # for non MLA attention backend that use torchair, we consider to pass kv_cache layout
                # of BSH ([num_blocks, block_size, kv_head_dim * head_size]) to attention.
                first_cache = torch.zeros(kv_cache_shape,
                                          dtype=self.kv_cache_dtype,
                                          device=self.device)
                second_cache = torch.zeros(kv_cache_shape,
                                           dtype=self.kv_cache_dtype,
                                           device=self.device)
                # atb reshape_and_cache does not support torchair.
                kv_caches[layer_name] = (
                    torch_npu.npu_format_cast(first_cache,
                                              ACL_FORMAT_FRACTAL_ND),
                    torch_npu.npu_format_cast(second_cache,
                                              ACL_FORMAT_FRACTAL_ND),
                )
            else:
                # for MLA attention backend that use torchair.
                hf_text_config = self.model_config.hf_text_config
                layer_kv_cache_nope = torch.zeros(
                    kv_cache_shape[:-1] + (hf_text_config.kv_lora_rank, ),
                    dtype=self.dtype,
                    pin_memory=True,
                    device=self.device)
                layer_kv_cache_pe = torch.zeros(
                    kv_cache_shape[:-1] +
                    (hf_text_config.qk_rope_head_dim, ),
                    dtype=self.dtype,
                    pin_memory=True,
                    device=self.device)
                kv_caches[layer_name] = (
                    torch_npu.npu_format_cast(layer_kv_cache_nope,
                                              acl_format),
                    torch_npu.npu_format_cast(layer_kv_cache_pe,
                                              acl_format),
                )

    bind_kv_cache(
        kv_caches,
        self.vllm_config.compilation_config.static_forward_context,
        self.kv_caches)
|
|
|
|
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
    """
    Generates the KVCacheSpec by parsing the kv cache format from each
    Attention module in the static forward context.

    Returns:
        KVCacheSpec: A dictionary mapping layer names to their KV cache
        format. Layers that do not need KV cache are not included.

    Raises:
        NotImplementedError: For encoder-decoder attention layers.
        ValueError: For an attention type that is not recognised.
    """
    use_mla = self.vllm_config.model_config.use_mla
    layers = self.vllm_config.compilation_config.static_forward_context

    specs: dict[str, KVCacheSpec] = {}
    for layer_name, module in layers.items():
        # FusedMoE entries also live in the forward context; they are
        # skipped here.
        if isinstance(module, FusedMoE):
            continue

        # TODO: Support other attention modules, e.g., sliding window,
        # cross-attention
        assert isinstance(module, Attention)
        attn_type = module.attn_type

        if attn_type in (AttentionType.ENCODER,
                         AttentionType.ENCODER_ONLY):
            # encoder-only attention does not need KV cache.
            continue
        if attn_type == AttentionType.ENCODER_DECODER:
            raise NotImplementedError
        if attn_type != AttentionType.DECODER:
            raise ValueError(
                f"Unknown attention type: {module.attn_type}")

        specs[layer_name] = FullAttentionSpec(
            block_size=self.block_size,
            num_kv_heads=module.num_kv_heads,
            head_size=module.head_size,
            dtype=self.kv_cache_dtype,
            use_mla=use_mla)

    return specs
|
|
|
|
def capture_model(self) -> None:
    """Capture torchair or ACL graphs by dummy-running each batch size.

    In eager mode (neither torchair graph mode nor ACL graph enabled) this
    only logs a message and returns. Otherwise every configured batch size
    is run, largest first, and the elapsed time and the drop in free NPU
    memory are logged.
    """
    start_time = time.perf_counter()
    start_free_npu_memory = torch.npu.mem_get_info()[0]

    # TODO(NeverRaR): Calling graph_capture(device=self.device) in
    # torchair graph capture can cause some issues, so now we just
    # temporarily split the codepath for the two different graph patterns.
    if not self.torchair_graph_enabled and not self.use_aclgraph:
        logger.info("Skipping NPU graph capture for eager mode.")
        return

    if self.torchair_graph_enabled:
        batch_sizes = self.torchair_graph_batch_sizes
        graph_num = len(batch_sizes)
        logger.info(
            "Capturing torchair graph, this usually takes %.1f~%.1f mins.",
            0.5 * graph_num, 1.5 * graph_num)
        # Trigger torchair graph capture for specific shapes.
        # Capture the large shapes first so that the smaller shapes
        # can reuse the memory pool allocated for the large shapes.
        for done, num_tokens in enumerate(reversed(batch_sizes), start=1):
            warmups = (self.vllm_config.compilation_config.
                       cudagraph_num_of_warmups)
            for _ in range(warmups):
                self._dummy_run(num_tokens,
                                is_compile=True,
                                with_prefill=False)
            # One more run after the warmups.
            self._dummy_run(num_tokens,
                            is_compile=True,
                            with_prefill=False)
            logger.info("Batchsize %d is compiled successfully: %d/%d.",
                        num_tokens, done, graph_num)
    else:
        # Trigger ACL graph capture for specific shapes.
        # Capture the large shapes first so that the smaller shapes
        # can reuse the memory pool allocated for the large shapes.
        with graph_capture(device=self.device):
            for num_tokens in reversed(self.aclgraph_batch_sizes):
                warmups = (self.vllm_config.compilation_config.
                           cudagraph_num_of_warmups)
                for _ in range(warmups):
                    self._dummy_run(num_tokens)
                self._dummy_run(num_tokens)

    end_time = time.perf_counter()
    end_free_npu_memory = torch.npu.mem_get_info()[0]
    # This usually takes 5~20 seconds.
    logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
                end_time - start_time,
                (start_free_npu_memory - end_free_npu_memory) / (1 << 30))
|
|
|
|
def _generate_draft_token_ids(
    self,
    sampled_token_ids: list[list[int]],
    sampling_metadata: SamplingMetadata,
) -> list[list[int]]:
    """Ask the drafter for draft tokens for every request in the batch.

    Args:
        sampled_token_ids: Tokens sampled this step, one list per request.
        sampling_metadata: Not read by this method.

    Returns:
        One list of draft token ids per request. The list is empty when the
        request sampled nothing, when speculative decoding is unsupported
        for it, or when the drafter proposes nothing.
    """
    # TODO(woosuk): Optimize.
    proposals: list[list[int]] = []
    for row, new_ids in enumerate(sampled_token_ids):
        num_new = len(new_ids)
        if not num_new:
            # Skip speculative decoding.
            proposals.append([])
            continue

        # Skip requests that require top-p, top-k, etc.
        req_id = self.input_batch.req_ids[row]
        if vllm_version_is("0.9.1"):
            unsupported = not is_spec_decode_supported(
                req_id, self.input_batch)
        else:
            unsupported = (
                req_id in self.input_batch.spec_decode_unsupported_reqs)
        if unsupported:
            proposals.append([])
            continue

        # Add sampled_token_ids to token_ids_cpu.
        begin = self.input_batch.num_tokens_no_spec[row]
        end = begin + num_new
        self.input_batch.token_ids_cpu[row, begin:end] = new_ids

        assert self.drafter is not None
        drafted = self.drafter.propose(
            self.input_batch.token_ids_cpu[row, :end])
        has_draft = not (drafted is None or len(drafted) == 0)
        proposals.append(drafted.tolist() if has_draft else [])
    return proposals
|
|
|
|
def _generate_mtp_token_ids(
    self,
    valid_sampled_token_ids: list[list[int]],
    sampling_metadata: SamplingMetadata,
    scheduler_output: "SchedulerOutput",
    spec_decode_metadata: SpecDecodeMetadata,
    positions: torch.Tensor,
    num_scheduled_tokens: int,
    hidden_states: torch.Tensor,
    attn_metadata: SpecDecodeMetadata,
):
    """Propose speculative tokens with the drafter.

    Args:
        valid_sampled_token_ids: Accepted tokens per request; an empty list
            means the request produced no token this step.
        sampling_metadata: Forwarded to the drafter.
        scheduler_output: Supplies the per-request scheduled token counts.
        spec_decode_metadata: ``None`` when this step had no draft tokens
            to verify; otherwise provides ``num_draft_tokens``.
        positions: Token positions for this step.
        num_scheduled_tokens: Number of tokens scheduled in this step.
        hidden_states: Target model hidden states.
        attn_metadata: Provides ``slot_mapping``, ``query_start_loc`` and
            ``block_tables``.

    Returns:
        The drafter's proposed token ids converted with ``tolist()``.
    """
    # Pick the token each request continues from.
    seed_ids: list[int] = []
    for row, token_ids in enumerate(valid_sampled_token_ids):
        if token_ids:
            # Common case.
            seed_ids.append(token_ids[-1])
            continue
        # Partial prefill (rare case).
        # Get the next token id from the request state.
        req_id = self.input_batch.req_ids[row]
        req_state = self.requests[req_id]
        seq_len = (req_state.num_computed_tokens +
                   scheduler_output.num_scheduled_tokens[req_id])
        seed_ids.append(req_state.get_token_id(seq_len))
    next_token_ids = torch.tensor(seed_ids,
                                  dtype=torch.int32,
                                  device=self.device)

    if spec_decode_metadata is not None:
        # TODO(woosuk): Refactor this.
        num_draft_tokens = spec_decode_metadata.num_draft_tokens
        # Requests without draft tokens count zero rejections.
        rejected_counts = [
            n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
            for i, n in enumerate(num_draft_tokens)
        ]
        num_rejected_tokens = torch.tensor(
            rejected_counts,
            dtype=torch.int32,
            device=self.device,
        )
        assert self.drafter is not None
        cu_num_tokens, token_indices = self.drafter.prepare_inputs(
            attn_metadata.query_start_loc,
            num_rejected_tokens,
        )
        target_token_ids = self.input_ids[token_indices]
        target_positions = positions[token_indices]
        target_hidden_states = hidden_states[token_indices]
        target_slot_mapping = attn_metadata.slot_mapping[token_indices]
    else:
        # input_ids can be None for multimodal models.
        target_token_ids = self.input_ids[:num_scheduled_tokens]
        target_positions = positions[:num_scheduled_tokens]
        target_hidden_states = hidden_states[:num_scheduled_tokens]
        target_slot_mapping = attn_metadata.slot_mapping
        cu_num_tokens = attn_metadata.query_start_loc

    assert self.drafter is not None
    draft_token_ids = self.drafter.propose(
        target_token_ids=target_token_ids,
        target_positions=target_positions,
        target_hidden_states=target_hidden_states,
        target_slot_mapping=target_slot_mapping,
        next_token_ids=next_token_ids,
        cu_num_tokens=cu_num_tokens,
        block_table=attn_metadata.block_tables,
        sampling_metadata=sampling_metadata,
    )
    return draft_token_ids.tolist()
|
|
|
|
def _get_prompt_logprobs_dict(
    self,
    hidden_states: torch.Tensor,
    scheduler_output: "SchedulerOutput",
) -> dict[str, Optional[LogprobsTensors]]:
    """Compute prompt logprobs for the requests that asked for them.

    For every request listed in ``self.input_batch.num_prompt_logprobs``
    the logits of the prompt positions scheduled in this step are
    computed from ``hidden_states`` and the resulting logprobs are copied
    into per-request CPU buffers kept in
    ``self.input_batch.in_progress_prompt_logprobs_cpu``.

    Args:
        hidden_states: Hidden states of the current step; the rows of a
            request are located through ``self.query_start_loc_np``.
        scheduler_output: Supplies ``num_scheduled_tokens`` per request.

    Returns:
        A dict mapping request id to its ``LogprobsTensors`` for the
        requests whose last prompt chunk is handled in this step. An
        empty dict is returned when no request wants prompt logprobs.

    Side effects:
        Requests that finished their prompt are removed from both
        ``num_prompt_logprobs`` and the in-progress dict of the input
        batch. ``torch.npu.synchronize()`` is called when the returned
        dict is non-empty.
    """
    requested = self.input_batch.num_prompt_logprobs
    if not requested:
        return {}

    in_flight = self.input_batch.in_progress_prompt_logprobs_cpu
    finished: dict[str, Optional[LogprobsTensors]] = {}

    # Prompt logprobs are rarely requested, so a plain per-request loop is
    # preferred here over a faster but more complicated batched version.
    done_req_ids = []
    for req_id, top_k in requested.items():
        scheduled = scheduler_output.num_scheduled_tokens[req_id]

        # Per-request state plus a device copy of its prompt tokens.
        req_state = self.requests[req_id]
        prompt_ids = req_state.prompt_token_ids
        prompt_len = len(prompt_ids)
        prompt_ids_dev = torch.tensor(prompt_ids).to(self.device,
                                                     non_blocking=True)

        # Reuse the CPU buffers filled by earlier chunks; on the first
        # chunk allocate buffers for the entire prompt, which later
        # chunks fill slice by slice.
        out = in_flight.get(req_id)
        if not out:
            out = LogprobsTensors.empty_cpu(prompt_len - 1, top_k + 1)
            in_flight[req_id] = out

        # Work out how many logits this step contributes.
        first_pos = req_state.num_computed_tokens
        first_target = first_pos + 1
        remaining = prompt_len - first_target
        if scheduled > remaining:
            # Last chunk of prompt tokens: the request is complete and
            # its tensors are handed back to the caller.
            n_logits = remaining
            done_req_ids.append(req_id)
            finished[req_id] = out
        else:
            # Intermediate chunk, more tokens remain. When
            # scheduled == remaining nothing is left to produce, but
            # returning the result is deferred to the next step, where
            # newly generated tokens are returned as well.
            n_logits = scheduled

        if n_logits <= 0:
            # Happens for the final chunk when exactly
            # (prompt_len - 1) tokens were prefilled in the prior step:
            # there are no more prompt logprobs to produce.
            continue

        # Slice out this request's rows. For a chunked prefill every
        # scheduled index yields a prompt logprob.
        row = self.input_batch.req_id_to_index[req_id]
        begin = self.query_start_loc_np[row].item()
        chunk_hidden = hidden_states[begin:begin + n_logits]
        logits = self.model.compute_logits(chunk_hidden, None)

        # The "sampled" token for prompt index i is the prompt token at
        # index i + 1; gather the logprob of exactly that token.
        targets = prompt_ids_dev[first_target:first_target + n_logits]

        all_logprobs = self.sampler.compute_logprobs(logits)
        token_ids, logprobs, ranks = self.sampler.gather_logprobs(
            all_logprobs, top_k, targets)

        # Asynchronous NPU->CPU transfer into this chunk's window.
        window = slice(first_pos, first_pos + n_logits)
        out.logprob_token_ids[window].copy_(token_ids, non_blocking=True)
        out.logprobs[window].copy_(logprobs, non_blocking=True)
        out.selected_token_ranks[window].copy_(ranks, non_blocking=True)

    # Requests that completed prefill no longer need tracking in the
    # batch-level dicts.
    for req_id in done_req_ids:
        del requested[req_id]
        del in_flight[req_id]

    # The non-blocking NPU->CPU copies must land before results are used.
    if finished:
        torch.npu.synchronize()

    return finished
|
|
|
|
def init_torchair_graph_batch_sizes(self):
    """Append the torchair graph batch sizes to the runner's list.

    Starting from ``max(4, tp_size)``, sizes are doubled and appended to
    ``self.torchair_graph_batch_sizes`` for as long as they do not exceed
    ``self.max_num_reqs``.
    """
    # NOTE: With all2all | mc2 the `num_tokens` dimension has to be sliced
    # into `tp_size` blocks, so the smallest size is at least `tp_size`.
    tp_size = get_tensor_model_parallel_world_size()
    graph_batch_size = max(4, tp_size)

    while graph_batch_size <= self.max_num_reqs:
        self.torchair_graph_batch_sizes.append(graph_batch_size)
        graph_batch_size *= 2
|
|
|
|
def select_torchair_padded_batch_size(self, batch_size: int):
    """Return the padded batch size to use for ``batch_size``.

    The result is the smallest entry of ``self.torchair_graph_batch_sizes``
    that is at least ``batch_size`` and below ``self.max_num_reqs``; when
    no entry qualifies, ``self.max_num_reqs`` is returned.
    """
    upper_bound = self.max_num_reqs
    large_enough = [
        size for size in self.torchair_graph_batch_sizes
        if batch_size <= size
    ]
    # Seeding with the upper bound caps the result and keeps the upper
    # bound itself on ties, like a strict "<" running minimum would.
    return min([upper_bound, *large_enough])
|