### What this PR does / why we need it?
When the SWA (sliding window attention) parameter is used in FIA (fused
infer attention), a head dim of 256 is not yet supported, so Gemma3, whose
head dim is 256, fails with an error. The code is therefore rolled back;
the SWA parameter will be brought back once CANN supports it.
### Does this PR introduce _any_ user-facing change?
Yes: the SWA parameter of FIA is removed.
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main: 7157596103
---------
Signed-off-by: nsdie <yeyifan@huawei.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#

from dataclasses import dataclass
from enum import Enum
from typing import ClassVar, List, Optional, Tuple, Type

import torch
import torch_npu
import vllm.envs as envs_vllm
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionLayer, AttentionType)
from vllm.attention.backends.registry import (AttentionBackendEnum,
                                              register_backend)
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
                                              AttentionMetadataBuilder)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec

from vllm_ascend.attention.context_parallel.common_cp import (
    AscendMetadataForDecode, AscendMetadataForPrefill)
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                         enable_cp, split_decodes_and_prefills,
                                         using_paged_attention)
from vllm_ascend.compilation.acl_graph import (
    get_draft_graph_params, get_graph_params,
    update_draft_graph_params_workspaces, update_graph_params_workspaces)
from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type,
                               weak_ref_tensors)

# default max value of sliding window size
SWA_INT_MAX = 2147483647
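# (2147483647 == 2**31 - 1, i.e. INT32_MAX: an effectively unbounded window.)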


@register_backend(AttentionBackendEnum.CUSTOM, "ASCEND")
class AscendAttentionBackend(AttentionBackend):
    accept_output_buffer: bool = True

    @staticmethod
    def get_name() -> str:
        # HACK(Ronald1995): vllm's `initialize_kv_cache` method in model
        # runner v2 asserts on the attention backend name; we simply report
        # FLASH_ATTN there to avoid the assertion error. Rectify this once
        # vllm drops the assertion.
        return "CUSTOM" if not envs_vllm.VLLM_USE_V2_MODEL_RUNNER else "FLASH_ATTN"

    @staticmethod
    def get_impl_cls() -> Type["AscendAttentionBackendImpl"]:
        if enable_cp():
            from vllm_ascend.attention.context_parallel.attention_cp import \
                AscendAttentionCPImpl
            return AscendAttentionCPImpl
        return AscendAttentionBackendImpl

    @staticmethod
    def get_builder_cls() -> type["AscendAttentionMetadataBuilder"]:
        if enable_cp():
            from vllm_ascend.attention.context_parallel.attention_cp import \
                AscendAttentionCPMetadataBuilder
            return AscendAttentionCPMetadataBuilder
        return AscendAttentionMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: List[torch.Tensor],
        dst_kv_cache: List[torch.Tensor],
        src_to_dst: torch.Tensor,
    ) -> None:
        src_key_cache, src_value_cache = src_kv_cache[0], src_kv_cache[1]
        dst_key_cache, dst_value_cache = dst_kv_cache[0], dst_kv_cache[1]
        src_indices = src_to_dst[:, 0]
        dst_indices = src_to_dst[:, 1]

        dst_key_cache[dst_indices] = src_key_cache[src_indices].to(
            dst_key_cache.device)
        dst_value_cache[dst_indices] = src_value_cache[src_indices].to(
            dst_value_cache.device)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        src_indices = src_to_dists[:, 0]
        dst_indices = src_to_dists[:, 1]

        for kv_cache in kv_caches:
            key_caches = kv_cache[0]
            value_caches = kv_cache[1]
            key_caches[dst_indices] = key_caches[src_indices]
            value_caches[dst_indices] = value_caches[src_indices]

    @staticmethod
    def get_supported_block_size() -> list[int]:
        return [128]
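
# Per-batch execution states; the metadata builder and the attention forward
# paths (e.g. _get_fia_params) branch on these to pick block tables and
# sequence-length layouts.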
class AscendAttentionState(Enum):
    PrefillNoCache = 0
    PrefillCacheHit = 1
    DecodeOnly = 2
    ChunkedPrefill = 3
    SpecDecoding = 4


@dataclass
class AscendMetadata:
    # **************************** Basic Properties ************************** #
    attn_mask: Optional[torch.Tensor] = None
    # Current state of this attention run.
    attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill

    # Number of tokens excluding padding.
    num_actual_tokens_pcp_padded: int = 0
    num_actual_tokens: int = 0
    num_decode_tokens: int = 0
    num_prefills: int = 0
    num_decodes: int = 0

    # The sequence length per sequence. Sequence length means the computed
    # tokens + new tokens (None if it is a decode).
    # (batch_size,)
    # TODO(Angazenn): The following parameters are quite redundant and
    # contain similar information (such as seq_lens and seq_lens_list). We
    # should simplify these parameters once the attention schema in
    # vLLM-Ascend is unified.
    seq_lens: torch.Tensor = None
    seq_lens_list: List[int] = None  # type: ignore
    actual_seq_lengths_q: List[int] = None  # type: ignore

    query_start_loc: torch.Tensor = None
    # Maximum query length in the batch (None for decoding).
    max_query_len: Optional[int] = None

    # ********************** KV Cache Related Properties ********************* #
    # Block addresses per sequence (Seq id -> list of physical block).
    # (batch_size, max_blocks_per_seq)
    block_tables: torch.Tensor = None

    # The indices of the token slots that input tokens will be stored into.
    # E.g., if `slot_mapping` is [35, 2, 17] and the block size is 16, the
    # three tokens are stored in the 3rd slot in block 2, 2nd slot in block 0,
    # and 1st slot in block 1, respectively.
    # (num_tokens,)
    slot_mapping: torch.Tensor = None
    # pcp (prefill context parallel) metadata
    prefill: Optional[AscendMetadataForPrefill] = None
    # dcp (decode context parallel) metadata
    decode_meta: Optional[AscendMetadataForDecode] = None

    causal: bool = True
    # runner_type in model_config.
    model_runner_type: str = ""
    # prefill reshape_and_cache event
    reshape_cache_event: torch.npu.Event = None

    # sliding window attention mask
    swa_mask: Optional[torch.Tensor] = None
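
# Builds AscendMetadata for each scheduler step and reports ACL-graph
# (cudagraph) support, so these attention ops can be captured in full graphs.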
class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
    # AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
    # Does this backend/builder reorder the batch? If not, set this to None.
    # Otherwise set it to the query length that will be pulled into the front
    # of the batch.
    reorder_batch_threshold: ClassVar[int] = 1

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.compilation_config = vllm_config.compilation_config
        self.device = device
        self.max_num_blocks_per_req = cdiv(
            self.model_config.max_model_len,
            AscendAttentionBackend.get_supported_block_size()[0])

        self.speculative_config = vllm_config.speculative_config
        self.decode_threshold = 1
        if self.speculative_config:
            spec_token_num = self.speculative_config.num_speculative_tokens
            self.decode_threshold += spec_token_num
            assert self.decode_threshold <= 16, (
                f"decode_threshold exceeds npu_fused_infer_attention_score "
                f"TND layout's limit of 16, got {self.decode_threshold}")

        AscendAttentionMetadataBuilder.reorder_batch_threshold = self.decode_threshold

        scheduler_config = vllm_config.scheduler_config
        self.chunked_prefill_enabled = scheduler_config.enable_chunked_prefill
    @classmethod
    def get_cudagraph_support(
        cls: type["AscendAttentionMetadataBuilder"],
        vllm_config: VllmConfig,
        kv_cache_spec: AttentionSpec,
    ) -> AttentionCGSupport:
        # Explicit override in case the underlying builder specialized this
        # getter. @override omitted only because of a mypy limitation with
        # type variables.
        return AttentionCGSupport.ALWAYS
    def reorder_batch(self, input_batch,
                      scheduler_output: "SchedulerOutput") -> bool:
        return False
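
    # Assembles the per-step AscendMetadata from the common metadata prepared
    # by the model runner: decode/prefill split, block tables, slot mapping,
    # and query/sequence lengths.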
    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: AscendCommonAttentionMetadata,
        fast_build: bool = False,
    ) -> AscendMetadata:
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:num_reqs + 1]

        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
            split_decodes_and_prefills(common_attn_metadata,
                                       decode_threshold=self.decode_threshold)

        block_table = common_attn_metadata.block_table_tensor
        seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]

        slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
        attn_mask = common_attn_metadata.attn_mask
        swa_mask = common_attn_metadata.swa_mask
        attn_state = common_attn_metadata.attn_state

        # TODO: Yet another unnecessary H2D while we already have a
        # query_start_loc on device.
        query_start_loc = query_start_loc_cpu.pin_memory().to(
            self.device, non_blocking=True)

        attn_metadata = AscendMetadata(
            num_actual_tokens=num_actual_tokens,
            num_decode_tokens=num_decode_tokens,
            block_tables=block_table,
            query_start_loc=query_start_loc,
            seq_lens=seq_lens,
            seq_lens_list=seq_lens.tolist(),
            max_query_len=common_attn_metadata.max_query_len,
            actual_seq_lengths_q=query_start_loc_cpu[1:].tolist(),
            slot_mapping=slot_mapping,
            attn_mask=attn_mask,
            swa_mask=swa_mask,
            attn_state=attn_state,
            num_prefills=num_prefills,
            num_decodes=num_decodes,
            causal=common_attn_metadata.causal,
            model_runner_type=self.model_config.runner_type)
        return attn_metadata
    def build_for_graph_capture(
        self,
        common_attn_metadata: AscendCommonAttentionMetadata,
        attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
    ):
        if attn_state in (AscendAttentionState.DecodeOnly,
                          AscendAttentionState.ChunkedPrefill):
            attn_metadata = self.build(
                common_prefix_len=0,
                common_attn_metadata=common_attn_metadata,
            )
        else:
            raise NotImplementedError(
                "Currently we only support building dummy metadata for the "
                "DecodeOnly and ChunkedPrefill states")

        attn_metadata.attn_state = attn_state
        return attn_metadata
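
# Attention implementation for Ascend NPUs: writes new KV into the paged
# cache, then runs either torch_npu paged attention or fused infer attention
# depending on the batch state, with dedicated ACL-graph capture paths.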
class AscendAttentionBackendImpl(AttentionImpl):

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        logits_soft_cap: Optional[float],
        attn_type: str,
        kv_sharing_target_layer_name: Optional[str],
        **kwargs,
    ) -> None:
        self.vllm_config = get_current_vllm_config()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.hidden_size = self.num_heads * self.head_size
        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = sliding_window
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes,
                                        dtype=torch.float32,
                                        device="npu")
        self.alibi_slopes = alibi_slopes
        self.attn_type = attn_type

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.key_cache = None
        self.value_cache = None
        self.is_kv_producer = (
            self.vllm_config.kv_transfer_config is not None
            and self.vllm_config.kv_transfer_config.is_kv_producer)
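
    # ACL-graph capture path for fused infer attention: the max workspace for
    # this token count is computed once and cached in the graph params, the
    # kernel launch is wrapped in a graph task group, and the task handle plus
    # an ExternalEvent are recorded so the captured kernel's arguments can be
    # updated and re-triggered at replay time.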
    def full_graph_fia(self, query: torch.Tensor, key: torch.Tensor,
                       value: torch.Tensor, attn_metadata: AscendMetadata,
                       output: torch.Tensor) -> Tuple[torch.Tensor, int]:
        key, value, block_size, block_table, actual_seq_lengths_kv \
            = self._get_fia_params(key, value, attn_metadata)

        num_tokens = attn_metadata.actual_seq_lengths_q[-1]
        forward_context = get_forward_context()
        if forward_context.is_draft_model:
            graph_params = get_draft_graph_params()
        else:
            graph_params = get_graph_params()
        actual_seq_lengths_q = attn_metadata.actual_seq_lengths_q
        # Prepare tensors for attention output
        # TODO: Refactor this to step-level instead of layer-level

        # Get workspace from cache or calculate it if not present.
        workspace = graph_params.workspaces.get(num_tokens)
        softmax_lse = torch.empty(1, dtype=query.dtype, device=query.device)
        if workspace is None:
            workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
                query=query,
                key=key,
                value=value,
                atten_mask=attn_metadata.attn_mask,
                block_table=block_table,
                input_layout="TND",
                block_size=block_size,
                actual_seq_lengths=actual_seq_lengths_q,
                actual_seq_lengths_kv=actual_seq_lengths_kv,
                num_key_value_heads=self.num_kv_heads,
                num_heads=self.num_heads,
                sparse_mode=3,
                scale=self.scale,
            )
            if forward_context.is_draft_model:
                update_draft_graph_params_workspaces(num_tokens, workspace)
            else:
                update_graph_params_workspaces(num_tokens, workspace)

        # Handle graph capturing mode
        stream = torch_npu.npu.current_stream()

        event = torch.npu.ExternalEvent()
        event.wait(stream)
        event.reset(stream)
        graph_params.events[num_tokens].append(event)
        graph_params.attn_params[num_tokens].append(
            (weak_ref_tensors(query), weak_ref_tensors(key),
             weak_ref_tensors(value), weak_ref_tensors(block_table),
             weak_ref_tensors(attn_metadata.attn_mask), block_size,
             actual_seq_lengths_kv, actual_seq_lengths_q, self.num_kv_heads,
             self.num_heads, self.scale, weak_ref_tensors(output),
             weak_ref_tensors(softmax_lse)))

        torch.npu.graph_task_group_begin(stream)
        torch_npu.npu_fused_infer_attention_score.out(
            query=query,
            key=key,
            value=value,
            atten_mask=attn_metadata.attn_mask,
            block_table=block_table,
            input_layout="TND",
            block_size=block_size,
            actual_seq_lengths=actual_seq_lengths_q,
            actual_seq_lengths_kv=actual_seq_lengths_kv,
            num_key_value_heads=self.num_kv_heads,
            num_heads=self.num_heads,
            scale=self.scale,
            sparse_mode=3,
            workspace=workspace,
            out=[output, softmax_lse],
        )

        output = output.view(num_tokens, self.num_heads, self.head_size)

        handle = torch.npu.graph_task_group_end(stream)
        graph_params.handles[num_tokens].append(handle)
        return output, num_tokens
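
    # Same capture pattern as full_graph_fia, but for the paged-attention
    # kernel used on the decode-only path.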
    def full_graph_pa(
        self,
        query: torch.Tensor,
        attn_metadata: AscendMetadata,
        output: Optional[torch.Tensor] = None,
    ):
        graph_params = get_graph_params()
        forward_context: ForwardContext = get_forward_context()
        num_tokens = query.shape[0]
        if forward_context.capturing:
            # Get workspace from cache or calculate it if not present.
            workspace = graph_params.workspaces.get(num_tokens)
            if workspace is None:
                workspace = torch_npu._npu_paged_attention_get_workspace(
                    query=query,
                    key_cache=self.key_cache,
                    value_cache=self.value_cache,
                    num_kv_heads=self.num_kv_heads,
                    num_heads=self.num_heads,
                    scale_value=self.scale,
                    block_table=attn_metadata.block_tables,
                    context_lens=attn_metadata.seq_lens,
                    out=output)
                update_graph_params_workspaces(num_tokens, workspace)

            # Handle graph capturing mode
            stream = torch_npu.npu.current_stream()

            event = torch.npu.ExternalEvent()
            event.wait(stream)
            event.reset(stream)
            graph_params.events[num_tokens].append(event)
            graph_params.attn_params[num_tokens].append((
                weak_ref_tensors(query),
                weak_ref_tensors(self.key_cache),
                weak_ref_tensors(self.value_cache),
                self.num_kv_heads,
                self.num_heads,
                self.scale,
                attn_metadata.block_tables,
                attn_metadata.seq_lens,
                weak_ref_tensors(output),
            ))

            torch.npu.graph_task_group_begin(stream)
            torch_npu._npu_paged_attention(
                query=query,
                key_cache=self.key_cache,
                value_cache=self.value_cache,
                num_kv_heads=self.num_kv_heads,
                num_heads=self.num_heads,
                scale_value=self.scale,
                block_table=attn_metadata.block_tables,
                context_lens=attn_metadata.seq_lens,
                out=output,
                workspace=workspace)
            handle = torch.npu.graph_task_group_end(stream)
            graph_params.handles[num_tokens].append(handle)
        return output
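
    # Selects per-state inputs for npu_fused_infer_attention_score: with no
    # cached KV (PrefillNoCache) the raw key/value tensors are used directly;
    # otherwise the paged KV cache is viewed as (num_blocks, block_size, H)
    # and addressed through the block table.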
    def _get_fia_params(self, key: torch.Tensor, value: torch.Tensor,
                        attn_metadata: AscendMetadata):
        if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
            block_size = 128
            block_table = None
            actual_seq_lengths_kv = attn_metadata.actual_seq_lengths_q
        elif attn_metadata.attn_state == \
                AscendAttentionState.PrefillCacheHit:
            batch_size = attn_metadata.seq_lens.shape[0]
            block_table = attn_metadata.block_tables[:batch_size, :]
            num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
            key = self.key_cache.view(  # type: ignore
                num_block, block_size, -1)
            value = self.value_cache.view(  # type: ignore
                num_block, block_size, -1)
            actual_seq_lengths_kv = attn_metadata.seq_lens_list
        elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
            num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
            key = self.key_cache.view(  # type: ignore
                num_block, block_size, -1)
            value = self.value_cache.view(  # type: ignore
                num_block, block_size, -1)
            block_table = attn_metadata.block_tables
            actual_seq_lengths_kv = attn_metadata.seq_lens_list
        # chunked prefill.
        else:
            num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
            key = self.key_cache.view(  # type: ignore
                num_block, block_size, -1)
            value = self.value_cache.view(  # type: ignore
                num_block, block_size, -1)
            block_table = attn_metadata.block_tables
            actual_seq_lengths_kv = attn_metadata.seq_lens_list
        return key, value, block_size, block_table, actual_seq_lengths_kv
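
    # Decode-only sliding-window path: queries are reshaped to BSH with one
    # token per sequence and the window is applied via pre_tokens.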
    def _forward_fia_slidingwindow(self, query: torch.Tensor,
                                   attn_metadata: AscendMetadata,
                                   output: torch.Tensor):
        batch_size = attn_metadata.seq_lens.shape[0]
        block_size = 128
        query = query.view(batch_size, 1, self.num_heads * self.head_size)
        key = self.key_cache
        value = self.value_cache
        if self.key_cache is not None and self.value_cache is not None:
            block_size = self.key_cache.shape[1]
            key = self.key_cache.flatten(2, 3).contiguous()
            value = self.value_cache.flatten(2, 3).contiguous()

        output, _ = torch_npu.npu_fused_infer_attention_score(
            query,
            key,
            value,
            num_heads=self.num_heads,
            num_key_value_heads=self.num_kv_heads,
            input_layout="BSH",
            block_size=block_size,
            pre_tokens=self.sliding_window,
            scale=self.scale,
            block_table=attn_metadata.block_tables,
            actual_seq_lengths=[1] * len(attn_metadata.seq_lens),
            actual_seq_lengths_kv=attn_metadata.seq_lens)

        output = output.view(batch_size, self.num_heads, self.head_size)
        return output
    def forward_fused_infer_attention(self, query: torch.Tensor,
                                      key: torch.Tensor, value: torch.Tensor,
                                      attn_metadata: AscendMetadata,
                                      output: torch.Tensor):
        forward_context: ForwardContext = get_forward_context()
        # We inherit ForwardContext in model runner v2; when model runner v2
        # is enabled there is no `capturing` attribute on forward_context, so
        # use getattr to avoid an AttributeError.
        if getattr(forward_context, "capturing", False):
            attn_output, num_tokens = self.full_graph_fia(
                query, key, value, attn_metadata, output)
            output[:num_tokens] = attn_output[:num_tokens]
            return output
        if (attn_metadata.attn_state == AscendAttentionState.DecodeOnly
                and self.sliding_window is not None
                and attn_metadata.seq_lens.shape[0] == query.size(0)):
            return self._forward_fia_slidingwindow(query, attn_metadata,
                                                   output)
        key, value, block_size, block_table, actual_seq_lengths_kv \
            = self._get_fia_params(key, value, attn_metadata)
        num_tokens = attn_metadata.actual_seq_lengths_q[-1]
        query = query[:num_tokens]
        if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
            key = key[:num_tokens]
            value = value[:num_tokens]
        attn_output, _ = torch_npu.npu_fused_infer_attention_score(
            query=query,
            key=key,
            value=value,
            atten_mask=attn_metadata.attn_mask,
            block_table=block_table,
            input_layout="TND",
            block_size=block_size,
            actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
            actual_seq_lengths_kv=actual_seq_lengths_kv,
            num_key_value_heads=self.num_kv_heads,
            num_heads=self.num_heads,
            scale=self.scale,
            sparse_mode=3,
        )

        attn_output = attn_output.view(num_tokens, self.num_heads,
                                       self.head_size)
        output[:num_tokens] = attn_output[:num_tokens]
        return output
    def forward_paged_attention(
        self,
        query: torch.Tensor,
        attn_metadata: AscendMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        forward_context: ForwardContext = get_forward_context()
        if forward_context.capturing:
            return self.full_graph_pa(query, attn_metadata, output)
        torch_npu._npu_paged_attention(query=query,
                                       key_cache=self.key_cache,
                                       value_cache=self.value_cache,
                                       num_kv_heads=self.num_kv_heads,
                                       num_heads=self.num_heads,
                                       scale_value=self.scale,
                                       block_table=attn_metadata.block_tables,
                                       context_lens=attn_metadata.seq_lens,
                                       out=output)
        return output
    def _forward_encoder_attention(self, query: torch.Tensor,
                                   key: torch.Tensor, value: torch.Tensor,
                                   attn_metadata: AscendMetadata,
                                   _: torch.Tensor) -> torch.Tensor:
        assert attn_metadata is not None

        if attn_metadata.causal:
            # Use sparse_mode 3 in the causal scenario.
            return torch_npu.npu_fusion_attention(
                query=query,
                key=key,
                value=value,
                head_num=self.num_heads,
                input_layout="TND",
                scale=self.scale,
                sparse_mode=3,
                atten_mask=attn_metadata.attn_mask,
                actual_seq_qlen=attn_metadata.actual_seq_lengths_q,
                actual_seq_kvlen=attn_metadata.actual_seq_lengths_q,
            )[0]
        else:
            # Use the default sparse_mode 0 in the normal scenario, which
            # means no mask is applied.
            return torch_npu.npu_fusion_attention(
                query=query,
                key=key,
                value=value,
                head_num=self.num_heads,
                input_layout="TND",
                scale=self.scale,
                actual_seq_qlen=attn_metadata.actual_seq_lengths_q,
                actual_seq_kvlen=attn_metadata.actual_seq_lengths_q,
            )[0]
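
    # Writes the new key/value tokens into the paged KV cache at the slots in
    # attn_metadata.slot_mapping; on the KV-producer side an NPU event is
    # recorded so KV transfer can wait for the cache write to finish.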
    def reshape_and_cache(
        self,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Tuple[torch.Tensor],
        attn_metadata: AscendMetadata,
    ):
        if len(kv_cache) > 1:
            if self.is_kv_producer:
                attn_metadata.reshape_cache_event = torch.npu.Event()
            if self.key_cache is None:
                self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
            slots = attn_metadata.slot_mapping
            if get_ascend_device_type() == AscendDeviceType.A5:
                # TODO: Once eagle runs through here it may fail because the
                # first dim of slot_mapping can be 0. Check whether the first
                # dim of slot_mapping must equal the first dim of key; if so,
                # the slots should be sliced.
                torch_npu.npu_scatter_pa_kv_cache(
                    key=key[:attn_metadata.num_actual_tokens],
                    value=value[:attn_metadata.num_actual_tokens].contiguous(),
                    key_cache=self.key_cache,
                    value_cache=self.value_cache,
                    slot_mapping=slots)
            else:
                torch_npu._npu_reshape_and_cache(
                    key=key[:attn_metadata.num_actual_tokens],
                    value=value[:attn_metadata.num_actual_tokens],
                    key_cache=self.key_cache,
                    value_cache=self.value_cache,
                    slot_indices=slots[:attn_metadata.num_actual_tokens])
            if self.is_kv_producer:
                attn_metadata.reshape_cache_event.record()
        return key, value
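
    # Dispatch: decode-only batches without a sliding window take the paged
    # attention kernel; everything else goes through fused infer attention.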
    def forward_impl(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Tuple[torch.Tensor],
        attn_metadata: AscendMetadata,
        output: torch.Tensor,
    ):
        num_tokens = query.shape[0]
        if (attn_metadata.attn_state == AscendAttentionState.DecodeOnly
                and using_paged_attention(num_tokens, self.vllm_config)
                and self.sliding_window is None):
            output = self.forward_paged_attention(query, attn_metadata,
                                                  output)
        else:
            output = self.forward_fused_infer_attention(
                query, key, value, attn_metadata, output)

        return output
    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Tuple[torch.Tensor],
        attn_metadata: AscendMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
        output_block_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with Ascend attention.

        Args:
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: shape =
                [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        assert output is not None, "Output tensor must be provided."

        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for AscendAttentionBackendImpl")

        assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
        attn_type = self.attn_type
        if attn_type not in [
                AttentionType.DECODER, AttentionType.ENCODER_ONLY
        ]:
            raise NotImplementedError("Encoder/Decoder cross-attention "
                                      "is not implemented for "
                                      "AscendAttentionBackendImpl")
        num_tokens = query.shape[0]
        if attn_metadata is None:
            return output.fill_(0)
        key, value = self.reshape_and_cache(key, value, kv_cache,
                                            attn_metadata)
        # pooling model branch
        if attn_metadata.model_runner_type == "pooling":
            attn_output = self._forward_encoder_attention(
                query, key, value, attn_metadata, output)
            output[:num_tokens] = attn_output[:num_tokens]
            return output
        output = self.forward_impl(query, key, value, kv_cache, attn_metadata,
                                   output)
        return output