### What this PR does / why we need it?
This PR implements Context Parallelism (CP) support for the Qwen3-Next
model, including PCP (Prefill Context Parallelism) and DCP
(Decode Context Parallelism).
- vLLM version: v0.15.0
- vLLM main: f176443446
---------
Signed-off-by: SunnyLee219 <3294305115@qq.com>
Signed-off-by: Jingchun Gao <gaojingchun1@huawei.com>
Signed-off-by: 白永斌 <baiyongbin3@h-partners.com>
Signed-off-by: Bai Yongbin <845473182@qq.com>
Co-authored-by: SunnyLee219 <3294305115@qq.com>
Co-authored-by: Jingchun Gao <gaojingchun1@huawei.com>
Co-authored-by: 白永斌 <baiyongbin3@h-partners.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
992 lines
40 KiB
Python
992 lines
40 KiB
Python
#
|
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# This file is a part of the vllm-ascend project.
|
|
#
|
|
|
|
from dataclasses import dataclass
|
|
from enum import Enum
|
|
from typing import ClassVar
|
|
|
|
import torch
|
|
import torch_npu
|
|
import vllm.envs as envs_vllm
|
|
from vllm.config import VllmConfig, get_current_vllm_config
|
|
from vllm.forward_context import ForwardContext, get_forward_context
|
|
from vllm.utils.math_utils import cdiv
|
|
from vllm.v1.attention.backend import ( # type: ignore
|
|
AttentionBackend,
|
|
AttentionCGSupport,
|
|
AttentionImpl,
|
|
AttentionLayer,
|
|
AttentionMetadataBuilder,
|
|
AttentionType,
|
|
)
|
|
from vllm.v1.attention.backends.registry import ( # type: ignore
|
|
AttentionBackendEnum,
|
|
register_backend,
|
|
)
|
|
from vllm.v1.core.sched.output import SchedulerOutput
|
|
from vllm.v1.kv_cache_interface import AttentionSpec, CrossAttentionSpec
|
|
|
|
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
|
|
from vllm_ascend.attention.context_parallel.common_cp import AscendMetadataForDecode, AscendMetadataForPrefill
|
|
from vllm_ascend.attention.utils import (
|
|
AscendCommonAttentionMetadata,
|
|
enable_cp,
|
|
split_decodes_and_prefills,
|
|
using_paged_attention,
|
|
)
|
|
from vllm_ascend.compilation.acl_graph import (
|
|
get_draft_graph_params,
|
|
get_graph_params,
|
|
update_draft_graph_params_workspaces,
|
|
update_graph_params_workspaces,
|
|
)
|
|
from vllm_ascend.device.device_op import DeviceOperator
|
|
from vllm_ascend.ops.flashcomm2_oshard_manager import flashcomm2_oshard_manager
|
|
from vllm_ascend.utils import weak_ref_tensors
|
|
|
|
# Upper bound used for the sliding-window size: the largest signed 32-bit
# integer (2**31 - 1 == 2147483647). It is passed as ``pre_tokens`` when a
# layer has no sliding window, i.e. the window is effectively unbounded.
SWA_INT_MAX = 2**31 - 1
|
|
|
|
|
|
@register_backend(AttentionBackendEnum.CUSTOM, "ASCEND")
class AscendAttentionBackend(AttentionBackend):
    """Attention backend for Ascend NPUs, registered in vLLM's CUSTOM slot.

    Chooses the context-parallel (CP) implementation and metadata builder
    when ``enable_cp()`` is true, otherwise the plain Ascend ones.
    """

    accept_output_buffer: bool = True

    @staticmethod
    def get_name() -> str:
        # HACK(Ronald1995): vllm `initialize_kv_cache` method in model runner v2 make
        # attention name assertion, we just set name to FLASH_ATTN to avoid assertion error.
        # rectify this when vllm disable the assertion.
        return "CUSTOM" if not envs_vllm.VLLM_USE_V2_MODEL_RUNNER else "FLASH_ATTN"

    @staticmethod
    def get_impl_cls() -> type["AscendAttentionBackendImpl"]:
        """Return the attention impl class (the CP variant when CP is enabled)."""
        if enable_cp():
            # Imported lazily so the CP module is only loaded when CP is on
            # (presumably also to avoid an import cycle — TODO confirm).
            from vllm_ascend.attention.context_parallel.attention_cp import AscendAttentionCPImpl

            return AscendAttentionCPImpl
        return AscendAttentionBackendImpl

    @staticmethod
    def get_builder_cls() -> type["AscendAttentionMetadataBuilder"]:
        """Return the metadata builder class (the CP variant when CP is enabled)."""
        if enable_cp():
            from vllm_ascend.attention.context_parallel.attention_cp import AscendAttentionCPMetadataBuilder

            return AscendAttentionCPMetadataBuilder
        return AscendAttentionMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        """Return the KV cache shape; the leading dimension of 2 holds key and value."""
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: list[torch.Tensor],
        dst_kv_cache: list[torch.Tensor],
        src_to_dst: torch.Tensor,
    ) -> None:
        """Copy KV blocks from ``src_kv_cache`` into ``dst_kv_cache``.

        Args:
            src_kv_cache: ``[key_cache, value_cache]`` to read from.
            dst_kv_cache: ``[key_cache, value_cache]`` to write into.
            src_to_dst: 2-column tensor; column 0 holds source block indices
                and column 1 the matching destination block indices.
        """
        src_key_cache, src_value_cache = src_kv_cache[0], src_kv_cache[1]
        dst_key_cache, dst_value_cache = dst_kv_cache[0], dst_kv_cache[1]
        src_indices = src_to_dst[:, 0]
        dst_indices = src_to_dst[:, 1]

        dst_key_cache[dst_indices] = src_key_cache[src_indices].to(dst_key_cache.device)
        # Each cache is moved to its own destination device (values previously
        # used the key cache's device).
        dst_value_cache[dst_indices] = src_value_cache[src_indices].to(dst_value_cache.device)

    @staticmethod
    def copy_blocks(
        kv_caches: list[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Copy blocks in place within every layer's KV cache.

        Args:
            kv_caches: Per-layer caches, each indexable as ``[0]`` (key) and
                ``[1]`` (value).
            src_to_dists: 2-column tensor of (source, destination) block indices.
        """
        src_indices = src_to_dists[:, 0]
        dst_indices = src_to_dists[:, 1]

        for kv_cache in kv_caches:
            key_caches = kv_cache[0]
            value_caches = kv_cache[1]
            key_caches[dst_indices] = key_caches[src_indices]
            value_caches[dst_indices] = value_caches[src_indices]

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int]:
        """Return the KV block sizes the Ascend kernels support."""
        return [128]
|
|
|
|
|
|
class AscendAttentionState(Enum):
    """State of one attention run.

    The state decides how key/value, block tables and KV lengths are fed to
    the fused-infer-attention kernel. ``DecodeOnly`` and ``ChunkedPrefill``
    are the states supported for graph-capture dummy metadata.
    """

    # Prefill with no cached context: attend over freshly computed K/V only.
    PrefillNoCache = 0
    # Prefill that reuses already-cached KV (read from the paged cache).
    PrefillCacheHit = 1
    # Batch contains only decode requests.
    DecodeOnly = 2
    # Mixed / chunked prefill; also the default state of AscendMetadata.
    ChunkedPrefill = 3
    # Speculative decoding step.
    SpecDecoding = 4
|
|
|
|
|
|
@dataclass
class AscendMetadata:
    """
    Per-layer attention metadata for Ascend FlashAttention backend.

    Contains attention masks, token counts, sequence lengths and KV cache
    related properties for attention computation. The non-CP path fills it in
    AscendAttentionMetadataBuilder.build; fields that builder does not set
    (e.g. ``prefill``, ``decode_meta``, ``num_actual_tokens_pcp_padded``)
    keep their defaults there.

    NOTE: field order defines the positional ``__init__`` signature; do not
    reorder fields.
    """

    # **************************** Basic Properties ************************** #
    # Attention mask passed to the FIA / fusion-attention kernels.
    attn_mask: torch.Tensor | None = None
    # Current state of this attention run.
    attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill

    # Token count including prefill-context-parallel (PCP) padding; not set by
    # the non-CP builder (presumably filled by the CP builder — TODO confirm).
    num_actual_tokens_pcp_padded: int = 0
    # Number of tokens excluding padding; reshape_and_cache writes only this
    # many keys/values into the cache.
    num_actual_tokens: int = 0
    # Decode/prefill split as returned by split_decodes_and_prefills.
    num_decode_tokens: int = 0
    num_prefills: int = 0
    num_decodes: int = 0
    # Not set by the non-CP builder; presumably used by the CP path — TODO confirm.
    num_decodes_flatten: int = 0

    # Per-request sequence length (computed tokens + new tokens).
    # (batch_size,)
    # The non-CP builder always sets it: seq_lens_cpu[:num_reqs], or the
    # device seq_lens for cross-attention.
    # TODO(Angazenn): The following parameters are quite redundant and
    # contains similar information (such as seq_lens seq_lens_list). We
    # should simplified these parameters once attention schema in vLLM-Ascend
    # is unified.
    seq_lens: torch.Tensor = None
    # Python-list copy of seq_lens; fed to FIA as the KV lengths.
    seq_lens_list: list[int] = None  # type: ignore
    # Cumulative query end offsets (query_start_loc[1:]) as a list; its last
    # element is used as the number of tokens to attend over.
    actual_seq_lengths_q: list[int] = None  # type: ignore

    # Device copy of the cumulative query start offsets.
    query_start_loc: torch.Tensor = None
    # Maximum query length in the batch (None for decoding).
    max_query_len: int | None = None

    # ********************** KV Cache Related Properties ********************* #
    # Block addresses per sequence (Seq id -> list of physical block).
    # (batch_size, max_blocks_per_seq)
    block_tables: torch.Tensor = None

    # The indices of the token slots that input tokens will be stored into.
    # E.g., if `slot_mapping` is [35, 2, 17] and the block size is 16, the
    # three tokens are stored at offset 3 of block 2, offset 2 of block 0,
    # and offset 1 of block 1, respectively (slot = block * block_size + offset).
    # (num_tokens,)
    slot_mapping: torch.Tensor = None
    # Prefill context parallelism (PCP) metadata; None unless a CP builder sets it.
    prefill: AscendMetadataForPrefill | None = None
    # Decode context parallelism (DCP) metadata; None unless a CP builder sets it.
    decode_meta: AscendMetadataForDecode | None = None

    # Whether attention is causal; the pooling/encoder path picks sparse_mode
    # from this.
    causal: bool = True
    # runner_type in model_config; "pooling" routes to the encoder attention path.
    model_runner_type: str = ""
    # prefill reshape_and_cache event; created and recorded around the KV cache
    # write when this instance is a KV producer.
    reshape_cache_event: torch.npu.Event = None

    # sliding window attention mask; used with sparse_mode=4 when sinks and a
    # sliding window are both present.
    swa_mask: torch.Tensor | None = None
|
|
|
|
|
|
class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
    """
    Builder for constructing AscendMetadata from CommonAttentionMetadata.

    Handles attention mask generation and metadata preparation for
    Ascend FlashAttention backend.
    """

    # Does this backend/builder reorder the batch?
    # If not, set this to None. Otherwise set it to the query
    # length that will be pulled into the front of the batch.
    reorder_batch_threshold: ClassVar[int] = 1

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.compilation_config = vllm_config.compilation_config
        self.device = device
        self.max_num_blocks_per_req = cdiv(
            self.model_config.max_model_len, AscendAttentionBackend.get_supported_kernel_block_sizes()[0]
        )

        self.speculative_config = vllm_config.speculative_config
        # Threshold passed to split_decodes_and_prefills; with speculative
        # decoding a decode request also carries its speculative tokens.
        self.decode_threshold = 1
        if self.speculative_config:
            spec_token_num = self.speculative_config.num_speculative_tokens
            self.decode_threshold += spec_token_num
            # Implicit concatenation keeps the message free of the source
            # indentation that backslash-continued string lines would embed.
            assert self.decode_threshold <= 16, (
                "decode_threshold exceeded "
                "npu_fused_infer_attention_score TND layout's limit of 16, "
                f"got {self.decode_threshold}"
            )

        # Assigned on the class (not the instance), so every reader of the
        # ClassVar sees the configured threshold.
        AscendAttentionMetadataBuilder.reorder_batch_threshold = self.decode_threshold

        scheduler_config = vllm_config.scheduler_config
        self.chunked_prefill_enabled = scheduler_config.enable_chunked_prefill
        self.attn_mask_builder = AttentionMaskBuilder(self.device)

    @classmethod
    def get_cudagraph_support(
        cls: type["AscendAttentionMetadataBuilder"],
        vllm_config: VllmConfig,
        kv_cache_spec: AttentionSpec,
    ) -> AttentionCGSupport:
        # Explicit override in case the underlying builder specialized this getter.
        # @override omitted only because of mypy limitation due to type variable.
        return AttentionCGSupport.ALWAYS

    def reorder_batch(self, input_batch, scheduler_output: "SchedulerOutput") -> bool:
        """This builder never reorders the batch itself; always returns False."""
        return False

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: AscendCommonAttentionMetadata,
        fast_build: bool = False,
    ) -> AscendMetadata:
        """Build the per-step AscendMetadata.

        Args:
            common_prefix_len: Not used by this builder.
            common_attn_metadata: Batch-level metadata shared by all layers.
            fast_build: Not used by this builder.

        Returns:
            A populated AscendMetadata for the current step.
        """
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[: num_reqs + 1]

        # The prefill token count is not needed here.
        num_decodes, num_prefills, num_decode_tokens, _ = split_decodes_and_prefills(
            common_attn_metadata, decode_threshold=self.decode_threshold
        )

        block_table = common_attn_metadata.block_table_tensor
        seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]

        slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
        # this slot_mapping override doesn't work since vllm will override it again. We should fix it vllm.
        # see: https://github.com/vllm-project/vllm/blob/ce88756b967c2c5006746a424c15dd59a284ed8c/vllm/model_executor/layers/attention/cross_attention.py#L117
        if isinstance(self.kv_cache_spec, CrossAttentionSpec):
            seq_lens = common_attn_metadata.seq_lens
            slot_mapping = common_attn_metadata.slot_mapping.to(torch.int32)
        attn_state = common_attn_metadata.attn_state

        # Get attn_mask and swa_mask from this builder's AttentionMaskBuilder.
        attn_mask = self.attn_mask_builder.get_attention_mask(self.model_config)

        # Only request a sliding-window mask when the model defines a real
        # window. `hasattr` alone also matched configs whose attribute exists
        # but is None. (model_config is dereferenced unconditionally in this
        # builder, so a separate None check on it is not needed.)
        swa_mask = None
        sliding_window = getattr(self.model_config.hf_text_config, "sliding_window", None)
        if sliding_window is not None:
            swa_mask = self.attn_mask_builder.get_swa_mask(self.model_config.dtype, sliding_window)

        # TODO: Yet another unnecessary H2D while we already have a query_start_loc on device
        query_start_loc = query_start_loc_cpu.pin_memory().to(self.device, non_blocking=True)

        attn_metadata = AscendMetadata(
            num_actual_tokens=num_actual_tokens,
            num_decode_tokens=num_decode_tokens,
            block_tables=block_table,
            query_start_loc=query_start_loc,
            seq_lens=seq_lens,
            seq_lens_list=seq_lens.tolist(),
            max_query_len=common_attn_metadata.max_query_len,
            actual_seq_lengths_q=query_start_loc_cpu[1:].tolist(),
            slot_mapping=slot_mapping,
            attn_mask=attn_mask,
            swa_mask=swa_mask,
            attn_state=attn_state,
            num_prefills=num_prefills,
            num_decodes=num_decodes,
            causal=common_attn_metadata.causal,
            model_runner_type=self.model_config.runner_type,
        )
        return attn_metadata

    def build_for_graph_capture(
        self,
        common_attn_metadata: AscendCommonAttentionMetadata,
        attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
    ):
        """Build dummy metadata for graph capture and tag it with ``attn_state``.

        Raises:
            NotImplementedError: if ``attn_state`` is not DecodeOnly or ChunkedPrefill.
        """
        if attn_state not in (AscendAttentionState.DecodeOnly, AscendAttentionState.ChunkedPrefill):
            raise NotImplementedError(
                "Currently we only support building dummy metadata for DecodeOnly and ChunkedPrefill state"
            )

        attn_metadata = self.build(
            common_prefix_len=0,
            common_attn_metadata=common_attn_metadata,
        )
        attn_metadata.attn_state = attn_state
        return attn_metadata
|
|
|
|
|
|
class AscendAttentionBackendImpl(AttentionImpl):
    """Attention implementation for Ascend NPUs.

    Decode-only batches that qualify for paged attention go through
    ``_npu_paged_attention``; everything else uses fused infer attention
    (FIA). Both kernels can be captured into ACL graphs and later updated
    via ``update_graph_params``.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
        kv_sharing_target_layer_name: str | None,
        sinks: torch.Tensor = None,
        **kwargs,
    ) -> None:
        self.vllm_config = get_current_vllm_config()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.hidden_size = self.num_heads * self.head_size
        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = sliding_window
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32, device="npu")
        self.alibi_slopes = alibi_slopes
        self.attn_type = attn_type

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        # Bound on the first reshape_and_cache call that receives a KV cache.
        self.key_cache = None
        self.value_cache = None
        self.is_kv_producer = (
            self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
        )
        self.sinks = sinks

    @staticmethod
    def update_graph_params(
        update_stream,
        forward_context,
        num_tokens,
        vllm_config,
        speculative_config=None,
        num_dcp_pcp_tokens=None,
        draft_attn_metadatas=None,
    ):
        """Re-issue captured attention kernels with this step's sequence lengths.

        On ``update_stream``, every captured task for ``num_tokens`` is rewritten
        between ``graph_task_update_begin``/``graph_task_update_end`` using the
        parameters saved at capture time plus the current metadata's lengths,
        and its event is recorded. ``speculative_config`` and
        ``num_dcp_pcp_tokens`` are not used by this implementation.
        """
        if using_paged_attention(num_tokens, vllm_config):
            # Paged Attention update logic
            if forward_context.is_draft_model:
                graph_params = get_draft_graph_params()
            else:
                graph_params = get_graph_params()
            with torch.npu.stream(update_stream):
                for key, param, handle, event in zip(
                    forward_context.attn_metadata,
                    graph_params.attn_params[num_tokens],
                    graph_params.handles[num_tokens],
                    graph_params.events[num_tokens],
                ):
                    (
                        query,
                        key_cache,
                        value_cache,
                        num_kv_heads,
                        num_heads,
                        scale,
                        block_table,
                        seq_lens,
                        output,
                    ) = param
                    # Use the current step's lengths, not the captured ones.
                    seq_lens = forward_context.attn_metadata[key].seq_lens

                    workspace = torch_npu._npu_paged_attention_get_workspace(
                        query=query,
                        key_cache=key_cache,
                        value_cache=value_cache,
                        num_kv_heads=num_kv_heads,
                        num_heads=num_heads,
                        scale_value=scale,
                        block_table=block_table,
                        context_lens=seq_lens,
                        out=output,
                    )
                    torch.npu.graph_task_update_begin(update_stream, handle)
                    torch_npu._npu_paged_attention(
                        query=query,
                        key_cache=key_cache,
                        value_cache=value_cache,
                        num_kv_heads=num_kv_heads,
                        num_heads=num_heads,
                        scale_value=scale,
                        block_table=block_table,
                        context_lens=seq_lens,
                        out=output,
                        workspace=workspace,
                    )
                    torch.npu.graph_task_update_end(update_stream)
                    event.record(update_stream)
        else:
            # FIA update logic
            if forward_context.is_draft_model:
                graph_params = get_draft_graph_params()
                attn_metadata = draft_attn_metadatas
                attn_keys = list(attn_metadata[0].keys())
            else:
                graph_params = get_graph_params()
                attn_metadata = forward_context.attn_metadata
                attn_keys = list(attn_metadata.keys())
            # For Qwen3-next, since the kv_cache_config has already categorized
            # linear_attn and self_attn, the attn_metadata is first arranged with
            # self_attn followed by linear_attn. Therefore, using zip directly
            # filters out the update operations for linear_attn.
            # TODO: We use a new variable `attn_keys` to ensure the loop count is
            # correct after get by `zip` because of the new structure of the attn_metadata
            # when running with the merged full eagle-graph. Should check it with Qwen3-next.
            num_layers = len(attn_keys)
            if num_layers == 0:
                return
            if forward_context.is_draft_model:
                # One captured entry per layer per draft step.
                attn_keys = attn_keys * (len(graph_params.attn_params[num_tokens]) // num_layers)
            attn_count = 0
            with torch.npu.stream(update_stream):
                for key, param, handle, event in zip(
                    attn_keys,
                    graph_params.attn_params[num_tokens],
                    graph_params.handles[num_tokens],
                    graph_params.events[num_tokens],
                ):
                    (
                        query,
                        key_cache,
                        value,
                        block_tables,
                        attn_mask,
                        block_size,
                        seq_lens,
                        query_start_loc,
                        num_kv_heads,
                        num_heads,
                        scale,
                        attn_output,
                        softmax_lse,
                    ) = param

                    if forward_context.is_draft_model:
                        draft_step = attn_count // num_layers
                        seq_lens = attn_metadata[draft_step][key].seq_lens_list
                        actual_seq_lengths_q = attn_metadata[draft_step][key].actual_seq_lengths_q
                        attn_count = attn_count + 1
                    else:
                        seq_lens = attn_metadata[key].seq_lens_list
                        actual_seq_lengths_q = attn_metadata[key].actual_seq_lengths_q

                    torch.npu.graph_task_update_begin(update_stream, handle)
                    torch_npu.npu_fused_infer_attention_score.out(
                        query=query,
                        key=key_cache,
                        value=value,
                        block_table=block_tables,
                        atten_mask=attn_mask,
                        input_layout="TND",
                        block_size=block_size,
                        actual_seq_lengths=actual_seq_lengths_q,
                        actual_seq_lengths_kv=seq_lens,
                        num_key_value_heads=num_kv_heads,
                        num_heads=num_heads,
                        scale=scale,
                        sparse_mode=3,
                        workspace=graph_params.workspaces.get(num_tokens),
                        out=[attn_output, softmax_lse],
                    )
                    torch.npu.graph_task_update_end(update_stream)

                    event.record(update_stream)

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        super().process_weights_after_loading(act_dtype)
        if flashcomm2_oshard_manager.flashcomm2_oshard_enable():
            flashcomm2_oshard_manager.post_process_after_loading()

    def full_graph_fia(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: AscendMetadata,
        output: torch.Tensor,
    ) -> tuple[torch.Tensor, int]:
        """Capture one FIA call into the current ACL graph.

        Saves the call's parameters, event and task handle into the (draft or
        main) graph params so ``update_graph_params`` can rewrite it later.

        Returns:
            ``(output viewed as (num_tokens, num_heads, head_size), num_tokens)``.
        """
        key, value, block_size, block_table, actual_seq_lengths_kv = self._get_fia_params(key, value, attn_metadata)

        num_tokens = attn_metadata.actual_seq_lengths_q[-1]
        forward_context = get_forward_context()
        if forward_context.is_draft_model:
            graph_params = get_draft_graph_params()
        else:
            graph_params = get_graph_params()
        actual_seq_lengths_q = attn_metadata.actual_seq_lengths_q
        # Prepare tensors for attention output
        # TODO: Refactor this to step-level instead of layer-level

        # Get workspace from cache or calculate it if not present.
        workspace = graph_params.workspaces.get(num_tokens)
        softmax_lse = torch.empty(1, dtype=query.dtype, device=query.device)
        if workspace is None:
            workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
                query=query,
                key=key,
                value=value,
                atten_mask=attn_metadata.attn_mask,
                block_table=block_table,
                input_layout="TND",
                block_size=block_size,
                actual_seq_lengths=actual_seq_lengths_q,
                actual_seq_lengths_kv=actual_seq_lengths_kv,
                num_key_value_heads=self.num_kv_heads,
                num_heads=self.num_heads,
                sparse_mode=3,
                scale=self.scale,
            )
            if forward_context.is_draft_model:
                update_draft_graph_params_workspaces(num_tokens, workspace)
            else:
                update_graph_params_workspaces(num_tokens, workspace)

        # Handle graph capturing mode
        stream = torch_npu.npu.current_stream()

        event = torch.npu.ExternalEvent()
        event.wait(stream)
        event.reset(stream)
        graph_params.events[num_tokens].append(event)
        # Tuple layout must match the unpacking in update_graph_params (FIA branch).
        graph_params.attn_params[num_tokens].append(
            (
                weak_ref_tensors(query),
                weak_ref_tensors(key),
                weak_ref_tensors(value),
                weak_ref_tensors(block_table),
                weak_ref_tensors(attn_metadata.attn_mask),
                block_size,
                actual_seq_lengths_kv,
                actual_seq_lengths_q,
                self.num_kv_heads,
                self.num_heads,
                self.scale,
                weak_ref_tensors(output),
                weak_ref_tensors(softmax_lse),
            )
        )

        torch.npu.graph_task_group_begin(stream)
        torch_npu.npu_fused_infer_attention_score.out(
            query=query,
            key=key,
            value=value,
            atten_mask=attn_metadata.attn_mask,
            block_table=block_table,
            input_layout="TND",
            block_size=block_size,
            actual_seq_lengths=actual_seq_lengths_q,
            actual_seq_lengths_kv=actual_seq_lengths_kv,
            num_key_value_heads=self.num_kv_heads,
            num_heads=self.num_heads,
            scale=self.scale,
            sparse_mode=3,
            workspace=workspace,
            out=[output, softmax_lse],
        )

        output = output.view(num_tokens, self.num_heads, self.head_size)

        handle = torch.npu.graph_task_group_end(stream)
        graph_params.handles[num_tokens].append(handle)
        return output, num_tokens

    def full_graph_pa(
        self,
        query: torch.Tensor,
        attn_metadata: AscendMetadata,
        output: torch.Tensor | None = None,
    ):
        """Capture one paged-attention call into the current ACL graph.

        Mirrors ``full_graph_fia``: draft-model captures go to the draft graph
        params, matching where ``update_graph_params`` reads them from. The
        workspace is always resolved, so it can no longer be referenced
        while unbound.
        """
        forward_context: ForwardContext = get_forward_context()
        is_draft_model = getattr(forward_context, "is_draft_model", False)
        graph_params = get_draft_graph_params() if is_draft_model else get_graph_params()
        num_tokens = query.shape[0]

        # Get workspace from cache or calculate it if not present.
        workspace = graph_params.workspaces.get(num_tokens)
        if workspace is None:
            workspace = torch_npu._npu_paged_attention_get_workspace(
                query=query,
                key_cache=self.key_cache,
                value_cache=self.value_cache,
                num_kv_heads=self.num_kv_heads,
                num_heads=self.num_heads,
                scale_value=self.scale,
                block_table=attn_metadata.block_tables,
                context_lens=attn_metadata.seq_lens,
                out=output,
            )
            if is_draft_model:
                update_draft_graph_params_workspaces(num_tokens, workspace)
            else:
                update_graph_params_workspaces(num_tokens, workspace)

        # Handle graph capturing mode
        stream = torch_npu.npu.current_stream()

        event = torch.npu.ExternalEvent()
        event.wait(stream)
        event.reset(stream)
        graph_params.events[num_tokens].append(event)
        # Tuple layout must match the unpacking in update_graph_params (PA branch).
        graph_params.attn_params[num_tokens].append(
            (
                weak_ref_tensors(query),
                weak_ref_tensors(self.key_cache),
                weak_ref_tensors(self.value_cache),
                self.num_kv_heads,
                self.num_heads,
                self.scale,
                attn_metadata.block_tables,
                attn_metadata.seq_lens,
                weak_ref_tensors(output),
            )
        )

        torch.npu.graph_task_group_begin(stream)
        torch_npu._npu_paged_attention(
            query=query,
            key_cache=self.key_cache,
            value_cache=self.value_cache,
            num_kv_heads=self.num_kv_heads,
            num_heads=self.num_heads,
            scale_value=self.scale,
            block_table=attn_metadata.block_tables,
            context_lens=attn_metadata.seq_lens,
            out=output,
            workspace=workspace,
        )
        handle = torch.npu.graph_task_group_end(stream)
        graph_params.handles[num_tokens].append(handle)
        return output

    def _get_fia_params(self, key: torch.Tensor, value: torch.Tensor, attn_metadata: AscendMetadata):
        """Select the key/value source and lengths for an FIA call.

        Returns:
            ``(key, value, block_size, block_table, actual_seq_lengths_kv)``.
        """
        if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
            # No cached context: attend over the given key/value directly.
            block_size = 128
            block_table = None
            actual_seq_lengths_kv = attn_metadata.actual_seq_lengths_q
            if self.attn_type == AttentionType.ENCODER_DECODER:
                actual_seq_lengths_kv = torch.cumsum(attn_metadata.seq_lens, dim=0).tolist()
            return key, value, block_size, block_table, actual_seq_lengths_kv

        # Every other state reads key/value from the paged KV cache.
        num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
        key = self.key_cache.view(num_block, block_size, -1)  # type: ignore
        value = self.value_cache.view(num_block, block_size, -1)  # type: ignore
        actual_seq_lengths_kv = attn_metadata.seq_lens_list
        if attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
            batch_size = attn_metadata.seq_lens.shape[0]
            block_table = attn_metadata.block_tables[:batch_size, :]
        else:
            # DecodeOnly and chunked prefill (and any other state).
            block_table = attn_metadata.block_tables
        return key, value, block_size, block_table, actual_seq_lengths_kv

    def _forward_fia_slidingwindow(self, query: torch.Tensor, attn_metadata: AscendMetadata, output: torch.Tensor):
        """Decode-only sliding-window attention using the BSH layout (one token per request)."""
        batch_size = attn_metadata.seq_lens.shape[0]
        block_size = 128
        query = query.view(batch_size, 1, self.num_heads * self.head_size)
        key = self.key_cache
        value = self.value_cache
        if self.key_cache is not None and self.value_cache is not None:
            block_size = self.key_cache.shape[1]
            key = self.key_cache.flatten(2, 3).contiguous()
            value = self.value_cache.flatten(2, 3).contiguous()

        attn_output, _ = torch_npu.npu_fused_infer_attention_score(
            query,
            key,
            value,
            num_heads=self.num_heads,
            num_key_value_heads=self.num_kv_heads,
            input_layout="BSH",
            block_size=block_size,
            pre_tokens=self.sliding_window,
            scale=self.scale,
            block_table=attn_metadata.block_tables,
            actual_seq_lengths=[1] * len(attn_metadata.seq_lens),
            actual_seq_lengths_kv=attn_metadata.seq_lens,
        )

        attn_output = attn_output.view(batch_size, self.num_heads, self.head_size)
        output[:batch_size] = attn_output[:batch_size]
        return output

    def forward_fused_infer_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: AscendMetadata,
        output: torch.Tensor,
    ):
        """Run attention through the fused-infer-attention (FIA) kernels.

        Captures into an ACL graph when capturing; otherwise dispatches to the
        sliding-window decode path, the sinks (v2) kernel, or the plain FIA kernel.
        """
        forward_context: ForwardContext = get_forward_context()
        # we inherit ForwardContext in model runner v2, when enable model
        # runner v2, there is not capturing attribute in forward_context,
        # just use getattr to avoid attribute error.
        if getattr(forward_context, "capturing", False):
            attn_output, num_tokens = self.full_graph_fia(query, key, value, attn_metadata, output)
            output[:num_tokens] = attn_output[:num_tokens]
            return output
        if (
            attn_metadata.attn_state == AscendAttentionState.DecodeOnly
            and self.sliding_window is not None
            and attn_metadata.seq_lens.shape[0] == query.size(0)
            and self.sinks is None
        ):
            return self._forward_fia_slidingwindow(query, attn_metadata, output)
        key, value, block_size, block_table, actual_seq_lengths_kv = self._get_fia_params(key, value, attn_metadata)
        num_tokens = attn_metadata.actual_seq_lengths_q[-1]
        query = query[:num_tokens]
        if (
            attn_metadata.attn_state == AscendAttentionState.PrefillNoCache
            and self.attn_type != AttentionType.ENCODER_DECODER
        ):
            key = key[:num_tokens]
            value = value[:num_tokens]
        if self.sinks is not None:
            actual_seq_qlen = attn_metadata.actual_seq_lengths_q
            if attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
                # One query token per request, so the cumulative query ends are
                # 1..batch. Built as list[int] like actual_seq_lengths_q
                # (the old int32 tensor's cumsum silently promoted to int64).
                actual_seq_qlen = list(range(1, len(attn_metadata.seq_lens_list) + 1))
            if self.sliding_window is not None:
                atten_mask = attn_metadata.swa_mask
                sparse_mode = 4
            else:
                atten_mask = attn_metadata.attn_mask
                sparse_mode = 3
            attn_output, _ = torch_npu.npu_fused_infer_attention_score_v2(
                query,
                key,
                value,
                num_query_heads=self.num_heads,
                num_key_value_heads=self.num_kv_heads,
                input_layout="TND",
                pre_tokens=self.sliding_window if self.sliding_window is not None else SWA_INT_MAX,
                next_tokens=0,
                atten_mask=atten_mask,
                sparse_mode=sparse_mode,
                softmax_scale=self.scale,
                block_table=block_table,
                block_size=block_size,
                actual_seq_qlen=actual_seq_qlen,
                actual_seq_kvlen=actual_seq_lengths_kv,
                learnable_sink=self.sinks,
            )
        else:
            attn_output, _ = torch_npu.npu_fused_infer_attention_score(
                query=query,
                key=key,
                value=value,
                atten_mask=attn_metadata.attn_mask,
                block_table=block_table,
                input_layout="TND",
                block_size=block_size,
                actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
                actual_seq_lengths_kv=actual_seq_lengths_kv,
                num_key_value_heads=self.num_kv_heads,
                num_heads=self.num_heads,
                scale=self.scale,
                sparse_mode=3,
            )

        attn_output = attn_output.view(num_tokens, self.num_heads, self.head_size)
        output[:num_tokens] = attn_output[:num_tokens]
        return output

    def forward_paged_attention(
        self,
        query: torch.Tensor,
        attn_metadata: AscendMetadata,
        output: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run decode attention with ``_npu_paged_attention`` (captured when capturing)."""
        forward_context: ForwardContext = get_forward_context()
        # Model runner v2's ForwardContext has no `capturing` attribute
        # (see forward_fused_infer_attention), so read it defensively.
        if getattr(forward_context, "capturing", False):
            return self.full_graph_pa(query, attn_metadata, output)
        torch_npu._npu_paged_attention(
            query=query,
            key_cache=self.key_cache,
            value_cache=self.value_cache,
            num_kv_heads=self.num_kv_heads,
            num_heads=self.num_heads,
            scale_value=self.scale,
            block_table=attn_metadata.block_tables,
            context_lens=attn_metadata.seq_lens,
            out=output,
        )
        return output

    def _forward_encoder_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: AscendMetadata,
        _: torch.Tensor,
    ) -> torch.Tensor:
        """Attention for the pooling-model path via ``npu_fusion_attention``."""
        assert attn_metadata is not None

        if attn_metadata.causal:
            # use sparse_mode 3 in causal scenario
            return torch_npu.npu_fusion_attention(
                query=query,
                key=key,
                value=value,
                head_num=self.num_heads,
                input_layout="TND",
                scale=self.scale,
                sparse_mode=3,
                atten_mask=attn_metadata.attn_mask,
                actual_seq_qlen=attn_metadata.actual_seq_lengths_q,
                actual_seq_kvlen=attn_metadata.actual_seq_lengths_q,
            )[0]
        else:
            # use default sparse_mode 0 in normal scenario, which means no mask works on it
            return torch_npu.npu_fusion_attention(
                query=query,
                key=key,
                value=value,
                head_num=self.num_heads,
                input_layout="TND",
                scale=self.scale,
                actual_seq_qlen=attn_metadata.actual_seq_lengths_q,
                actual_seq_kvlen=attn_metadata.actual_seq_lengths_q,
            )[0]

    def reshape_and_cache(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: tuple[torch.Tensor],
        attn_metadata: AscendMetadata,
        output: torch.Tensor,
    ):
        """Write the new key/value into the paged KV cache.

        Binds ``self.key_cache``/``self.value_cache`` on first use. When this
        instance is a KV producer, an event is recorded after the write.

        Returns:
            ``(query, key, value, output)``, unchanged in this implementation.
        """
        if len(kv_cache) > 1:
            if self.is_kv_producer:
                attn_metadata.reshape_cache_event = torch.npu.Event()
            if self.key_cache is None:
                self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
            slots = attn_metadata.slot_mapping
            encoder_decoder = self.attn_type == AttentionType.ENCODER_DECODER
            DeviceOperator.reshape_and_cache(
                key=key[: attn_metadata.num_actual_tokens] if not encoder_decoder else key,
                value=value[: attn_metadata.num_actual_tokens] if not encoder_decoder else value,
                key_cache=self.key_cache,
                value_cache=self.value_cache,
                # quick fix to make sure slots is int32 for cross attention case.
                # see: https://github.com/vllm-project/vllm/blob/ce88756b967c2c5006746a424c15dd59a284ed8c/vllm/model_executor/layers/attention/cross_attention.py#L117
                slot_mapping=slots[: attn_metadata.num_actual_tokens] if not encoder_decoder else slots.to(torch.int32),
            )
            if self.is_kv_producer:
                attn_metadata.reshape_cache_event.record()
        return query, key, value, output

    def forward_impl(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: tuple[torch.Tensor],
        attn_metadata: AscendMetadata,
        output: torch.Tensor,
    ):
        """Use paged attention for eligible decode-only batches, FIA otherwise."""
        num_tokens = query.shape[0]
        if (
            attn_metadata.attn_state == AscendAttentionState.DecodeOnly
            and using_paged_attention(num_tokens, self.vllm_config)
            and self.sliding_window is None
        ):
            output = self.forward_paged_attention(query, attn_metadata, output)
        else:
            output = self.forward_fused_infer_attention(query, key, value, attn_metadata, output)

        return output

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: tuple[torch.Tensor],
        attn_metadata: AscendMetadata,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Forward pass with Ascend attention.
        Args:
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: shape =
                [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        assert output is not None, "Output tensor must be provided."

        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError("fused output quantization is not yet supported for AscendAttentionBackendImpl")

        assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
        num_tokens = query.shape[0]
        if attn_metadata is None:
            # No metadata (e.g. dummy/profile run): emit zeros.
            return output.fill_(0)
        output_padded = None
        if key is not None and value is not None:
            output_padded = output
            # Subclasses may return a different (e.g. padded) output buffer here.
            query, key, value, output_padded = self.reshape_and_cache(
                query, key, value, kv_cache, attn_metadata, output
            )
        # pooling model branch
        if attn_metadata.model_runner_type == "pooling":
            attn_output = self._forward_encoder_attention(query, key, value, attn_metadata, output)
            output[:num_tokens] = attn_output[:num_tokens]
            return output
        if output_padded is not None:
            attn_output = self.forward_impl(query, key, value, kv_cache, attn_metadata, output_padded)
        else:
            attn_output = self.forward_impl(query, key, value, kv_cache, attn_metadata, output)
        output[:num_tokens] = attn_output[:num_tokens]
        return output
|