# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

###############################################################################
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
###############################################################################

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type

import torch
import vllm_hpu_extension.kernels as kernels
import vllm_hpu_extension.ops as ops
from vllm_hpu_extension.flags import enabled_flags
from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionLayer,
                                              AttentionMetadata, AttentionType,
                                              is_quantized_kv_cache)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.ops.hpu_paged_attn import (HPUPagedAttention,
                                               HPUPagedAttentionMetadata)
from vllm.logger import init_logger

logger = init_logger(__name__)


class HPUAttentionBackend(AttentionBackend):

    @staticmethod
    def get_name() -> str:
        return "HPU_ATTN"

    @staticmethod
    def get_impl_cls() -> Type["HPUAttentionImpl"]:
        return HPUAttentionImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return HPUAttentionMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        return HPUPagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                    num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dsts: torch.Tensor,
    ) -> None:
        HPUPagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dsts)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dsts: torch.Tensor,
    ) -> None:
        HPUPagedAttention.copy_blocks(kv_caches, src_to_dsts)


@dataclass
class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
    """Metadata for HPUAttentionBackend."""
    # Currently, input sequences can only contain all prompts
    # or all decoding. True if all sequences are prompts.
    is_prompt: bool
    attn_bias: Optional[torch.Tensor]
    seq_lens_tensor: Optional[torch.Tensor]
    context_lens_tensor: Optional[torch.Tensor]


class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
    """
    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prefill_tokens ----------------->|
    |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|

    Otherwise, the layout is as follows:
    |<----------------- num_decode_tokens ------------------>|
    |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|

    Generation tokens can contain padding when HPU graphs are used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.
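
    For example, a prefill batch with three prompts of lengths 2, 3 and 4 is
    laid out back to back as |<--prefill_0-->|<--prefill_1-->|<--prefill_2-->|
    with num_prefill_tokens = 2 + 3 + 4 = 9.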
""" def __init__( self, num_heads: int, head_size: int, scale: float, num_kv_heads: int, alibi_slopes: Optional[List[float]], sliding_window: Optional[int], kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, max_seq_len: int = 4096, attn_type: str = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[str] = None, use_irope: bool = False, ) -> None: super(AttentionImpl, self).__init__() if kv_sharing_target_layer_name is not None: raise NotImplementedError("KV sharing is not supported in V0.") if use_irope: logger.warning_once( "Using irope in HPU is not supported yet, it will fall back " "to global attention for long context.") self.kv_cache_dtype = kv_cache_dtype self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) self.matmul_qk = Matmul() self.softmax = Softmax() self.matmul_av = Matmul() self.batch2block_matmul = Matmul() self.block2batch_matmul = Matmul() self.k_cache = VLLMKVCache() self.v_cache = VLLMKVCache() self.fused_scaled_dot_product_attention = kernels.fsdpa() self.prefill_impl = 'naive' if "flex_attention" in enabled_flags(): self.prefill_impl = 'flex' if "fsdpa" in enabled_flags(): assert alibi_slopes is None, \ 'Prefill with FusedSDPA not supported with alibi slopes!' self.prefill_impl = 'fsdpa' self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads self.sliding_window = sliding_window self.alibi_slopes = alibi_slopes if alibi_slopes is not None: alibi_slopes_tensor = torch.tensor(alibi_slopes, dtype=torch.bfloat16) self.alibi_slopes = alibi_slopes_tensor self.num_queries_per_kv = self.num_heads // self.num_kv_heads if self.prefill_impl == 'fsdpa': assert alibi_slopes is None, \ 'Prefill with FusedSDPA not supported with alibi slopes!' supported_head_sizes = HPUPagedAttention.get_supported_head_sizes() if head_size not in supported_head_sizes: raise ValueError( f"Head size {head_size} is not supported by PagedAttention. " f"Supported head sizes are: {supported_head_sizes}.") self.attn_type = attn_type if self.attn_type != AttentionType.DECODER: raise NotImplementedError("Encoder self-attention and " "encoder/decoder cross-attention " "are not implemented for " "HPUAttentionImpl") if is_quantized_kv_cache(self.kv_cache_dtype): raise NotImplementedError( "HPUAttention with FP8 KV cache not yet supported") def forward( self, layer: AttentionLayer, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: HPUAttentionMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. Args: query: shape = [num_tokens, num_heads * head_size] key: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size] kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] attn_metadata: Metadata for attention. 
        Returns:
            shape = [batch_size, seq_len, num_heads * head_size]
        """
        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for HPUAttentionImpl")

        batch_size, seq_len, hidden_size = query.shape
        _, seq_len_kv, _ = key.shape
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)
        block_indices = attn_metadata.block_indices
        block_offsets = attn_metadata.block_offsets
        key_cache = None
        value_cache = None
        if attn_metadata.is_prompt and self.attn_type \
                is not AttentionType.ENCODER_ONLY:
            key = key.unflatten(0, (block_indices.size(0), -1))
            value = value.unflatten(0, (block_indices.size(0), -1))
        if kv_cache is not None and isinstance(kv_cache, tuple):
            key_cache, value_cache = HPUPagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
            key_cache = self.k_cache(key, key_cache, block_indices,
                                     block_offsets)
            value_cache = self.v_cache(value, value_cache, block_indices,
                                       block_offsets)

        if attn_metadata.is_prompt:
            # Prompt run.
            query_shape = (batch_size, seq_len, self.num_heads, self.head_size)
            kv_shape = (batch_size, seq_len_kv, self.num_kv_heads,
                        self.head_size)

            attn_bias = attn_metadata.attn_bias
            if attn_bias is not None and self.alibi_slopes is not None:
                position_bias = _make_alibi_bias(self.alibi_slopes,
                                                 self.num_kv_heads,
                                                 attn_bias.dtype,
                                                 attn_bias.shape[-1])
                attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1))
                attn_bias.add_(position_bias)

            block_list = attn_metadata.block_list if attn_metadata \
                and attn_metadata.block_list is not None else None

            out = ops.prompt_attention(
                impl=self.prefill_impl,
                query=query.view(query_shape),
                key=key.view(kv_shape),
                value=value.view(kv_shape),
                is_causal=True,
                attn_bias=attn_bias,
                valid_seq_lengths=attn_metadata.seq_lens_tensor,
                **self.common_attention_args(block_list, key_cache,
                                             value_cache))
            output = out.reshape(batch_size, seq_len, hidden_size)
        else:
            # Decoding run.
            output = HPUPagedAttention.forward_decode(
                query=query,
                block_mapping=attn_metadata.block_mapping,
                block_bias=attn_metadata.attn_bias,
                block_groups=attn_metadata.block_groups,
                **self.common_attention_args(attn_metadata.block_list,
                                             key_cache, value_cache))
        # Reshape the output tensor.
        return output.view(batch_size, seq_len, hidden_size)

    def common_attention_args(self,
                              block_list=None,
                              key_cache=None,
                              value_cache=None):
        fsdpa_op = self.fused_scaled_dot_product_attention.apply \
            if self.fused_scaled_dot_product_attention is not None else None
        return {
            'scale': self.scale,
            'matmul_qk_op': self.matmul_qk,
            'matmul_av_op': self.matmul_av,
            'batch2block_matmul_op': self.batch2block_matmul,
            'block2batch_matmul_op': self.block2batch_matmul,
            'fsdpa_op': fsdpa_op,
            'keys_fetch_func': self.k_cache.fetch_from_cache,
            'values_fetch_func': self.v_cache.fetch_from_cache,
            'softmax_op': self.softmax,
            'block_list': block_list,
            'key_cache': key_cache,
            'value_cache': value_cache,
        }


def _make_alibi_bias(
    alibi_slopes: torch.Tensor,
    num_kv_heads: int,
    dtype: torch.dtype,
    seq_len: int,
) -> torch.Tensor:
    bias = torch.arange(seq_len, dtype=dtype)
    # NOTE(zhuohan): HF uses
    #     `bias = bias[None, :].repeat(seq_len, 1)`
    # here. We find that both biases give the same results, but
    # the bias below more accurately follows the original ALiBi
    # paper.
    # Calculate a matrix where entry (i, j) is the relative position j - i
    # of key j with respect to query i.
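    # For example, with seq_len = 3 the next line produces
    #     [[ 0,  1,  2],
    #      [-1,  0,  1],
    #      [-2, -1,  0]]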
    bias = bias[None, :] - bias[:, None]

    padded_len = (seq_len + 7) // 8 * 8
    num_heads = alibi_slopes.shape[0]
    bias = torch.empty(
        1,  # batch size
        num_heads,
        seq_len,
        padded_len,
        device=alibi_slopes.device,
        dtype=dtype,
    )[:, :, :, :seq_len].copy_(bias)
    bias.mul_(alibi_slopes[:, None, None])
    if num_heads != num_kv_heads:
        bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
    return bias