207 lines
6.9 KiB
Python
207 lines
6.9 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
"""Attention layer with PagedAttention and Triton prefix prefill."""
|
|
|
|
import torch
|
|
|
|
from vllm import _custom_ops as ops
|
|
from vllm.attention.backends.abstract import AttentionType
|
|
from vllm.logger import init_logger
|
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
|
QuantKey,
|
|
kFp8StaticTensorSym,
|
|
)
|
|
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
|
|
from vllm.v1.attention.backends.rocm_attn import (
|
|
RocmAttentionBackend,
|
|
RocmAttentionImpl,
|
|
RocmAttentionMetadataBuilder,
|
|
)
|
|
|
|
# Module-level logger; the impl below uses it to announce (once) that the
# AITER unified attention path is active.
logger = init_logger(__name__)
|
|
|
|
|
|
class RocmAiterUnifiedAttentionBackend(RocmAttentionBackend):
    """ROCm attention backend that runs AITER's unified attention kernel.

    Everything except the backend name, the implementation class, the KV
    cache shape check and the cascade policy is inherited from
    ``RocmAttentionBackend``; the metadata builder is shared with it.
    """

    # forward() of the impl requires a preallocated ``output`` tensor.
    accept_output_buffer: bool = True

    @staticmethod
    def get_name() -> str:
        """Return the registry name of this backend."""
        return "ROCM_AITER_UNIFIED_ATTN"

    @staticmethod
    def get_impl_cls() -> type["RocmAiterUnifiedAttentionImpl"]:
        """Return the attention implementation class used by this backend."""
        return RocmAiterUnifiedAttentionImpl

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Return the shape of the combined key/value cache tensor.

        The leading dimension of size 2 stacks the key plane and the value
        plane. ``cache_dtype_str`` is accepted for interface compatibility
        and does not influence the shape.

        Raises:
            ValueError: if ``block_size`` is not a multiple of 16.
        """
        if block_size % 16:
            raise ValueError("Block size must be a multiple of 16.")
        kv_planes = 2  # one plane for keys, one for values
        return (kv_planes, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        """Report that cascade attention is never used; arguments are ignored."""
        return False

    @staticmethod
    def get_builder_cls() -> type["RocmAttentionMetadataBuilder"]:
        """Return the metadata builder, shared with the plain ROCm backend."""
        return RocmAttentionMetadataBuilder
|
|
|
|
|
|
class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
    """ROCm attention implementation backed by AITER's unified attention.

    Construction is delegated to ``RocmAttentionImpl``; :meth:`forward`
    writes new keys/values into the paged KV cache and then calls
    ``aiter.ops.triton.unified_attention.unified_attention``.
    """

    def fused_output_quant_supported(self, quant_key: QuantKey) -> bool:
        """Return whether output quantization ``quant_key`` can be fused.

        Only ``kFp8StaticTensorSym`` (static per-tensor symmetric FP8, per
        its name) is accepted.
        """
        return quant_key == kFp8StaticTensorSym

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
        sinks: torch.Tensor | None = None,
    ) -> None:
        """Initialize via ``RocmAttentionImpl`` and bind the AITER kernel.

        All arguments are forwarded positionally, unchanged, to
        ``RocmAttentionImpl.__init__``. ``kv_sharing_target_layer_name`` is
        the name of an earlier attention layer whose KV cache this layer
        reuses; when it is set, :meth:`forward` does not write keys/values
        into the cache.
        """
        super().__init__(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            sliding_window,
            kv_cache_dtype,
            logits_soft_cap,
            attn_type,
            kv_sharing_target_layer_name,
            sinks,
        )
        logger.info_once(
            "Using aiter unified attention for RocmAiterUnifiedAttentionImpl"
        )
        # Deferred import: importing this module does not require `aiter`;
        # only constructing this impl does.
        from aiter.ops.triton.unified_attention import unified_attention

        self.unified_attention = unified_attention

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Forward pass with AITER unified attention.

        Stores the new keys/values into the paged KV cache (unless this
        layer shares another layer's cache or they are None), then runs the
        unified attention kernel over the cache and writes into ``output``.

        Args:
            layer: The attention layer; its ``_k_scale``/``_v_scale`` (and,
                for FP8 caches, ``_q_scale_float``) are read here.
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]; may be None
                when the KV cache is already populated (e.g. cross-attention).
            value: shape = [num_tokens, num_kv_heads, head_size]; may be
                None together with ``key``.
            kv_cache: shape =
                [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention; None in a profiling run.
            output: Preallocated output buffer; required.
            output_scale: Optional scale for fused output quantization,
                forwarded unchanged to the kernel.
            output_block_scale: Must be None; block-scaled output
                quantization is not supported.

        Returns:
            ``output``. Its first ``num_actual_tokens`` rows are written by
            the kernel; in a profiling run the whole buffer is zero-filled.

        Raises:
            NotImplementedError: if ``output_block_scale`` is provided.
        """
        assert output is not None, "Output tensor must be provided."

        if output_block_scale is not None:
            raise NotImplementedError(
                "fused block_scale output quantization is not yet supported"
                " for RocmAiterUnifiedAttentionImpl"
            )

        if attn_metadata is None:
            # Profiling run.
            return output.fill_(0)

        assert attn_metadata.use_cascade is False

        # IMPORTANT!
        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
        # in this method. For example, `view` and `slice` (or `[:n]`) operations
        # are surprisingly slow even in the case they do not invoke any GPU ops.
        # Minimize the PyTorch ops in this method as much as possible.
        # Whenever making a change in this method, please benchmark the
        # performance to make sure it does not introduce any overhead.

        num_actual_tokens = attn_metadata.num_actual_tokens

        # Split the stacked cache into its key and value planes (views).
        key_cache, value_cache = kv_cache.unbind(0)

        # key and value may be None in the case of cross attention. They are
        # calculated once based on the output from the encoder and then cached
        # in KV cache.
        if (
            self.kv_sharing_target_layer_name is None
            and key is not None
            and value is not None
        ):
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            ops.reshape_and_cache_flash(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )

        if self.kv_cache_dtype.startswith("fp8"):
            # Reinterpret the cache storage as FP8 for the kernel; a dtype
            # view does not copy.
            key_cache = key_cache.view(self.fp8_dtype)
            value_cache = value_cache.view(self.fp8_dtype)
            # The query is passed unscaled (q_descale=None below), so only a
            # unit q scale yields correct results.
            assert layer._q_scale_float == 1.0, (
                "A non 1.0 q_scale is not currently supported."
            )

        cu_seqlens_q = attn_metadata.query_start_loc
        seqused_k = attn_metadata.seq_lens
        max_seqlen_q = attn_metadata.max_query_len
        max_seqlen_k = attn_metadata.max_seq_len
        block_table = attn_metadata.block_table

        # Broadcast the layer's K/V scales to one entry per (sequence,
        # kv_head); `expand` returns a view, so nothing is allocated.
        # Presumably cu_seqlens_q has num_seqs + 1 entries — verify.
        descale_shape = (
            cu_seqlens_q.shape[0] - 1,
            key.shape[1] if key is not None else self.num_kv_heads,
        )

        self.unified_attention(
            q=query[:num_actual_tokens],
            k=key_cache,
            v=value_cache,
            out=output[:num_actual_tokens],
            cu_seqlens_q=cu_seqlens_q,
            max_seqlen_q=max_seqlen_q,
            seqused_k=seqused_k,
            max_seqlen_k=max_seqlen_k,
            softmax_scale=self.scale,
            causal=True,
            alibi_slopes=self.alibi_slopes,
            window_size=self.sliding_window,
            block_table=block_table,
            softcap=self.logits_soft_cap,
            q_descale=None,  # Not supported
            k_descale=layer._k_scale.expand(descale_shape),
            v_descale=layer._v_scale.expand(descale_shape),
            sinks=self.sinks,
            output_scale=output_scale,
        )

        return output
|