# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import List, Optional, Tuple

try:
    import intel_extension_for_pytorch.llm.modules as ipex_modules
    _use_ipex = True
# AttributeError is to handle a bug in ipex https://github.com/intel/intel-extension-for-pytorch/pull/813
except (ImportError, AttributeError):
    _use_ipex = False

import torch

from vllm import _custom_ops as ops


class _PagedAttention:
    """Paged-attention helpers implemented on top of ``vllm._custom_ops``.

    Every method is a static method. All of them accept and ignore any
    extra positional arguments (``*args``).
    """

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        """Return a fresh list of the supported attention head sizes."""
        return [32, 64, 80, 96, 112, 128, 192, 256]

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        *args,
    ) -> Tuple[int, ...]:
        """Return the KV-cache shape ``(2, num_blocks, elems_per_block)``.

        ``elems_per_block`` is ``block_size * num_kv_heads * head_size``.
        The leading dimension of 2 is indexed as key (0) / value (1) by
        ``split_kv_cache``.
        """
        elems_per_block = block_size * num_kv_heads * head_size
        return (2, num_blocks, elems_per_block)

    @staticmethod
    def split_kv_cache(
        kv_cache: torch.Tensor,
        num_kv_heads: int,
        head_size: int,
        *args,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Split a combined KV cache into reshaped key and value views.

        Returns:
            ``(key_cache, value_cache)`` where the key view has shape
            ``(num_blocks, num_kv_heads, head_size // x, -1, x)`` with
            ``x = 16 // kv_cache.element_size()``, and the value view has
            shape ``(num_blocks, num_kv_heads, head_size, -1)``.
        """
        # Number of cache elements that fit in 16 bytes.
        elems_per_16_bytes = 16 // kv_cache.element_size()
        n_blocks = kv_cache.shape[1]

        keys = kv_cache[0].view(
            n_blocks,
            num_kv_heads,
            head_size // elems_per_16_bytes,
            -1,
            elems_per_16_bytes,
        )
        values = kv_cache[1].view(n_blocks, num_kv_heads, head_size, -1)
        return keys, values

    @staticmethod
    def write_to_paged_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
        *args,
    ) -> None:
        """Forward ``key``/``value`` to ``ops.reshape_and_cache``.

        ``slot_mapping`` is flattened to 1-D before being handed over; all
        other arguments are passed through unchanged.
        """
        flat_slots = slot_mapping.flatten()
        ops.reshape_and_cache(
            key,
            value,
            key_cache,
            value_cache,
            flat_slots,
            kv_cache_dtype,
            k_scale,
            v_scale,
        )

    @staticmethod
    def forward_decode(
        output: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
        max_context_len: int,
        kv_cache_dtype: str,
        num_kv_heads: int,
        scale: float,
        alibi_slopes: Optional[torch.Tensor],
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
        *args,
    ) -> None:
        """Call ``ops.paged_attention_v1`` with the given decode inputs.

        Nothing is returned; ``output`` is passed to the op as its first
        argument.
        """
        # Last dimension of the value view built by ``split_kv_cache``.
        block_size = value_cache.shape[3]

        # Fixed values for the trailing tensor-parallel / blocksparse
        # arguments of the op.
        # NOTE(review): presumably these disable block-sparse attention --
        # confirm against ``ops.paged_attention_v1``.
        tp_rank = 0
        blocksparse_local_blocks = 0
        blocksparse_vert_stride = 0
        blocksparse_block_size = 64
        blocksparse_head_sliding_step = 0

        ops.paged_attention_v1(
            output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
            kv_cache_dtype,
            k_scale,
            v_scale,
            tp_rank,
            blocksparse_local_blocks,
            blocksparse_vert_stride,
            blocksparse_block_size,
            blocksparse_head_sliding_step,
        )

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
        *args,
    ) -> None:
        """Pass per-cache key and value tensors to ``ops.copy_blocks``.

        Index 0 of each entry in ``kv_caches`` is collected as the key
        cache and index 1 as the value cache.
        """
        all_keys = [cache[0] for cache in kv_caches]
        all_values = [cache[1] for cache in kv_caches]
        ops.copy_blocks(all_keys, all_values, src_to_dists)


class _IPEXPagedAttention(_PagedAttention):
    """Variant of ``_PagedAttention`` that calls IPEX's ``PagedAttention``.

    Overrides the cache split, cache write and decode steps; everything
    else is inherited.
    """

    @staticmethod
    def split_kv_cache(
        kv_cache: torch.Tensor,
        num_kv_heads: int,
        head_size: int,
        *args,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Split a combined KV cache into key and value views.

        Both returned views have shape
        ``(num_blocks, num_kv_heads, -1, head_size)``.
        """
        n_blocks = kv_cache.shape[1]
        view_shape = (n_blocks, num_kv_heads, -1, head_size)

        keys = kv_cache[0].view(*view_shape)
        values = kv_cache[1].view(*view_shape)
        return keys, values

    @staticmethod
    def write_to_paged_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
        *args,
    ) -> None:
        """Forward ``key``/``value`` to IPEX's ``reshape_and_cache``.

        ``slot_mapping`` is flattened and converted with ``.int()`` first.
        ``kv_cache_dtype``, ``k_scale`` and ``v_scale`` are accepted but
        not passed on.
        """
        flat_slots = slot_mapping.flatten().int()
        ipex_modules.PagedAttention.reshape_and_cache(
            key, value, key_cache, value_cache, flat_slots)

    @staticmethod
    def forward_decode(
        output: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
        max_context_len: int,
        kv_cache_dtype: str,
        num_kv_heads: int,
        scale: float,
        alibi_slopes: Optional[torch.Tensor],
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
        *args,
    ) -> None:
        """Call IPEX's ``single_query_cached_kv_attention`` for decoding.

        ``kv_cache_dtype``, ``k_scale`` and ``v_scale`` are accepted but
        not passed on. Nothing is returned; ``output`` is passed to the
        IPEX call as its first argument.
        """
        # Third dimension of the views built by ``split_kv_cache``.
        block_size = value_cache.shape[2]

        # Each KV head index is repeated once per query head that maps to
        # it, e.g. [0, 0, 1, 1] for 2 KV heads and ``query.size(1) == 4``.
        queries_per_kv_head = query.size(1) // num_kv_heads
        kv_head_ids = torch.arange(
            0,
            num_kv_heads,
            device="cpu",
            dtype=torch.int32,
        )
        head_mapping = kv_head_ids.view(num_kv_heads, 1).repeat_interleave(
            queries_per_kv_head).flatten()

        ipex_modules.PagedAttention.single_query_cached_kv_attention(
            output,
            query.contiguous(),
            key_cache,
            value_cache,
            head_mapping,
            scale,
            block_tables,
            context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
        )


# Use the IPEX-backed implementation only when its import succeeded.
PagedAttention = _IPEXPagedAttention if _use_ipex else _PagedAttention