forked from EngineX-Hygon/enginex-hygon-vllm
init src 0.9.2
This commit is contained in:
195
vllm/attention/ops/ipex_attn.py
Normal file
195
vllm/attention/ops/ipex_attn.py
Normal file
@@ -0,0 +1,195 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch.llm.modules as ipex_modules
|
||||
_use_ipex = True
|
||||
# AttributeError is to handle a bug in ipex https://github.com/intel/intel-extension-for-pytorch/pull/813
|
||||
except (ImportError, AttributeError):
|
||||
_use_ipex = False
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
|
||||
class _PagedAttention:
|
||||
|
||||
@staticmethod
|
||||
def get_supported_head_sizes() -> List[int]:
|
||||
return [32, 64, 80, 96, 112, 128, 192, 256]
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
*args,
|
||||
) -> Tuple[int, ...]:
|
||||
return 2, num_blocks, block_size * num_kv_heads * head_size
|
||||
|
||||
@staticmethod
|
||||
def split_kv_cache(
|
||||
kv_cache: torch.Tensor,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
*args,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
x = 16 // kv_cache.element_size()
|
||||
num_blocks = kv_cache.shape[1]
|
||||
|
||||
key_cache = kv_cache[0]
|
||||
key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x,
|
||||
-1, x)
|
||||
value_cache = kv_cache[1]
|
||||
value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
|
||||
return key_cache, value_cache
|
||||
|
||||
@staticmethod
|
||||
def write_to_paged_cache(
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
kv_cache_dtype: str,
|
||||
k_scale: torch.Tensor,
|
||||
v_scale: torch.Tensor,
|
||||
*args,
|
||||
) -> None:
|
||||
ops.reshape_and_cache(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
slot_mapping.flatten(),
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def forward_decode(
|
||||
output: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
block_tables: torch.Tensor,
|
||||
context_lens: torch.Tensor,
|
||||
max_context_len: int,
|
||||
kv_cache_dtype: str,
|
||||
num_kv_heads: int,
|
||||
scale: float,
|
||||
alibi_slopes: Optional[torch.Tensor],
|
||||
k_scale: torch.Tensor,
|
||||
v_scale: torch.Tensor,
|
||||
*args,
|
||||
) -> None:
|
||||
tp_rank: int = 0
|
||||
blocksparse_local_blocks: int = 0
|
||||
blocksparse_vert_stride: int = 0
|
||||
blocksparse_block_size: int = 64
|
||||
blocksparse_head_sliding_step: int = 0
|
||||
block_size = value_cache.shape[3]
|
||||
|
||||
ops.paged_attention_v1(
|
||||
output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_tables,
|
||||
context_lens,
|
||||
block_size,
|
||||
max_context_len,
|
||||
alibi_slopes,
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
tp_rank,
|
||||
blocksparse_local_blocks,
|
||||
blocksparse_vert_stride,
|
||||
blocksparse_block_size,
|
||||
blocksparse_head_sliding_step,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def copy_blocks(
|
||||
kv_caches: List[torch.Tensor],
|
||||
src_to_dists: torch.Tensor,
|
||||
*args,
|
||||
) -> None:
|
||||
key_caches = [kv_cache[0] for kv_cache in kv_caches]
|
||||
value_caches = [kv_cache[1] for kv_cache in kv_caches]
|
||||
ops.copy_blocks(key_caches, value_caches, src_to_dists)
|
||||
|
||||
|
||||
class _IPEXPagedAttention(_PagedAttention):
|
||||
|
||||
@staticmethod
|
||||
def split_kv_cache(
|
||||
kv_cache: torch.Tensor,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
*args,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
num_blocks = kv_cache.shape[1]
|
||||
|
||||
key_cache = kv_cache[0]
|
||||
key_cache = key_cache.view(num_blocks, num_kv_heads, -1, head_size)
|
||||
value_cache = kv_cache[1]
|
||||
value_cache = value_cache.view(num_blocks, num_kv_heads, -1, head_size)
|
||||
return key_cache, value_cache
|
||||
|
||||
@staticmethod
|
||||
def write_to_paged_cache(
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
kv_cache_dtype: str,
|
||||
k_scale: torch.Tensor,
|
||||
v_scale: torch.Tensor,
|
||||
*args,
|
||||
) -> None:
|
||||
ipex_modules.PagedAttention.reshape_and_cache(
|
||||
key, value, key_cache, value_cache,
|
||||
slot_mapping.flatten().int())
|
||||
|
||||
@staticmethod
|
||||
def forward_decode(
|
||||
output: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
block_tables: torch.Tensor,
|
||||
context_lens: torch.Tensor,
|
||||
max_context_len: int,
|
||||
kv_cache_dtype: str,
|
||||
num_kv_heads: int,
|
||||
scale: float,
|
||||
alibi_slopes: Optional[torch.Tensor],
|
||||
k_scale: torch.Tensor,
|
||||
v_scale: torch.Tensor,
|
||||
*args,
|
||||
) -> None:
|
||||
block_size = value_cache.shape[2]
|
||||
head_mapping = torch.arange(
|
||||
0,
|
||||
num_kv_heads,
|
||||
device="cpu",
|
||||
dtype=torch.int32,
|
||||
).view(num_kv_heads,
|
||||
1).repeat_interleave(query.size(1) // num_kv_heads).flatten()
|
||||
ipex_modules.PagedAttention.single_query_cached_kv_attention(
|
||||
output, query.contiguous(), key_cache, value_cache, head_mapping,
|
||||
scale, block_tables, context_lens, block_size, max_context_len,
|
||||
alibi_slopes)
|
||||
|
||||
|
||||
PagedAttention = _IPEXPagedAttention if _use_ipex else _PagedAttention
|
||||
Reference in New Issue
Block a user